{ "cells": [ { "cell_type": "markdown", "id": "7c4c44c8", "metadata": {}, "source": [ "## MOSAIKS feature extraction\n", "\n", "This tutorial demonstrates the **MOSAIKS** method for extracting _feature vectors_ from satellite imagery patches for use in downstream modeling tasks. It will show:\n", "- How to extract 1km$^2$ patches of Sentinel 2 multispectral imagery for a list of latitude, longitude points\n", "- How to extract summary features from each of these imagery patches\n", "- How to use the summary features in a linear model of the population density at each point\n", "\n", "### Background\n", "\n", "Consider the case where you have a dataset of latitude and longitude points associated with some dependent variable (for example: population density, weather, housing prices, biodiversity) and, potentially, other independent variables. You would like to model the dependent variable as a function of the independent variables, but instead of including latitude and longitude directly in this model, you would like to include some high dimensional representation of what the Earth looks like at that point (that hopefully explains some of the variance in the dependent variable!). From the computer vision literature, there are various [representation learning techniques](https://en.wikipedia.org/wiki/Feature_learning) that can be used to do this, i.e. extract _features vectors_ from imagery. This notebook gives an implementation of the technique described in [Rolf et al. 2021](https://www.nature.com/articles/s41467-021-24638-z), \"A generalizable and accessible approach to machine learning with global satellite imagery\" called Multi-task Observation using Satellite Imagery & Kitchen Sinks (**MOSAIKS**). For more information about **MOSAIKS** see the [project's webpage](http://www.globalpolicy.science/mosaiks).\n", "\n", "\n", "**Notes**:\n", "- This example uses [Sentinel-2 Level-2A data](https://planetarycomputer.microsoft.com/dataset/sentinel-2-l2a). 
The techniques used here apply equally well to other remote-sensing datasets.\n", "- If you're running this on the [Planetary Computer Hub](http://planetarycomputer.microsoft.com/compute), make sure to choose the **GPU - PyTorch** profile when presented with the form to choose your environment." ] }, { "cell_type": "code", "execution_count": 1, "id": "ab74cb2f", "metadata": {}, "outputs": [], "source": [ "import warnings\n", "import time\n", "import os\n", "\n", "RASTERIO_BEST_PRACTICES = dict( # See https://github.com/pangeo-data/cog-best-practices\n", " CURL_CA_BUNDLE=\"/etc/ssl/certs/ca-certificates.crt\",\n", " GDAL_DISABLE_READDIR_ON_OPEN=\"EMPTY_DIR\",\n", " AWS_NO_SIGN_REQUEST=\"YES\",\n", " GDAL_MAX_RAW_BLOCK_CACHE_SIZE=\"200000000\",\n", " GDAL_SWATH_SIZE=\"200000000\",\n", " VSI_CURL_CACHE_SIZE=\"200000000\",\n", ")\n", "os.environ.update(RASTERIO_BEST_PRACTICES)\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "import torch\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "from torch.utils.data import Dataset, DataLoader\n", "\n", "import rasterio\n", "import rasterio.warp\n", "import rasterio.mask\n", "import shapely.geometry\n", "import geopandas\n", "import dask_geopandas\n", "from sklearn.linear_model import RidgeCV\n", "from sklearn.model_selection import train_test_split\n", "from sklearn.metrics import r2_score\n", "from scipy.stats import spearmanr\n", "from scipy.linalg import LinAlgWarning\n", "from dask.distributed import Client\n", "\n", "\n", "warnings.filterwarnings(action=\"ignore\", category=LinAlgWarning, module=\"sklearn\")\n", "\n", "import pystac_client\n", "import planetary_computer as pc" ] }, { "cell_type": "markdown", "id": "e372135a", "metadata": {}, "source": [ "First we define the pytorch model that we will use to extract the features and a helper method. The **MOSAIKS** methodology describes several ways to do this and we use the simplest." 
] }, { "cell_type": "code", "execution_count": 2, "id": "e13c154a", "metadata": {}, "outputs": [], "source": [ "def featurize(input_img, model, device):\n", " \"\"\"Helper method for running an image patch through the model.\n", "\n", " Args:\n", " input_img (np.ndarray): Image in (C x H x W) format with a dtype of uint8.\n", " model (torch.nn.Module): Feature extractor network\n", " \"\"\"\n", " assert len(input_img.shape) == 3\n", " input_img = torch.from_numpy(input_img / 255.0).float()\n", " input_img = input_img.to(device)\n", " with torch.no_grad():\n", " feats = model(input_img.unsqueeze(0)).cpu().numpy()\n", " return feats\n", "\n", "\n", "class RCF(nn.Module):\n", " \"\"\"A model for extracting Random Convolution Features (RCF) from input imagery.\"\"\"\n", "\n", " def __init__(self, num_features=16, kernel_size=3, num_input_channels=3):\n", " super(RCF, self).__init__()\n", "\n", " # We create `num_features / 2` filters so require `num_features` to be divisible by 2\n", " assert num_features % 2 == 0\n", "\n", " self.conv1 = nn.Conv2d(\n", " num_input_channels,\n", " num_features // 2,\n", " kernel_size=kernel_size,\n", " stride=1,\n", " padding=0,\n", " dilation=1,\n", " bias=True,\n", " )\n", "\n", " nn.init.normal_(self.conv1.weight, mean=0.0, std=1.0)\n", " nn.init.constant_(self.conv1.bias, -1.0)\n", "\n", " def forward(self, x):\n", " x1a = F.relu(self.conv1(x), inplace=True)\n", " x1b = F.relu(-self.conv1(x), inplace=True)\n", "\n", " x1a = F.adaptive_avg_pool2d(x1a, (1, 1)).squeeze()\n", " x1b = F.adaptive_avg_pool2d(x1b, (1, 1)).squeeze()\n", "\n", " if len(x1a.shape) == 1: # case where we passed a single input\n", " return torch.cat((x1a, x1b), dim=0)\n", " elif len(x1a.shape) == 2: # case where we passed a batch of > 1 inputs\n", " return torch.cat((x1a, x1b), dim=1)" ] }, { "cell_type": "markdown", "id": "42f303c0", "metadata": {}, "source": [ "Next, we initialize the model and pytorch components" ] }, { "cell_type": "code", "execution_count": 
3, "id": "60b896be", "metadata": {}, "outputs": [], "source": [ "num_features = 1024\n", "\n", "device = torch.device(\"cuda\")\n", "model = RCF(num_features).eval().to(device)" ] }, { "cell_type": "markdown", "id": "6c804888", "metadata": { "tags": [] }, "source": [ "### Read dataset of (lat, lon) points and corresponding labels" ] }, { "cell_type": "markdown", "id": "6462a61e", "metadata": {}, "source": [ "We read a CSV of 100,000 randomly sampled (lat, lon) points over the US and the corresponding population living roughly within 1km$^2$ of the points from the [Gridded Population of the World](https://sedac.ciesin.columbia.edu/downloads/data/gpw-v4/gpw-v4-population-density-rev10/gpw-v4-population-density-rev10_2015_30_sec_tif.zip) dataset. This data comes from the [Code Ocean capsule](https://codeocean.com/capsule/6456296/tree/v2) that accompanies the Rolf et al. 2021 paper." ] }, { "cell_type": "code", "execution_count": 4, "id": "92ef8030", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "

67968 rows × 7 columns

\n", "
" ], "text/plain": [ " V1 V1.1 ID lon lat population \\\n", "1 1 1 1225,1595 -103.046735 37.932314 0.085924 \n", "2 2 2 1521,2455 -91.202438 34.647579 0.808222 \n", "4 4 4 828,3849 -72.003660 42.116711 101.286320 \n", "5 5 5 1530,2831 -86.024002 34.545552 28.181724 \n", "6 6 6 1097,2696 -87.883281 39.309455 11.923701 \n", "... ... ... ... ... ... ... \n", "99995 99995 99995 1164,1275 -107.453915 38.591906 0.006152 \n", "99996 99996 99996 343,2466 -91.050941 46.876806 1.782686 \n", "99997 99997 99997 634,3922 -70.998272 44.067430 8.383357 \n", "99998 99998 99998 357,2280 -93.612615 46.744854 8.110552 \n", "100000 100000 100000 1449,2691 -87.952143 35.459263 1.972914 \n", "\n", " geometry \n", "1 POINT (-103.04673 37.93231) \n", "2 POINT (-91.20244 34.64758) \n", "4 POINT (-72.00366 42.11671) \n", "5 POINT (-86.02400 34.54555) \n", "6 POINT (-87.88328 39.30945) \n", "... ... \n", "99995 POINT (-107.45391 38.59191) \n", "99996 POINT (-91.05094 46.87681) \n", "99997 POINT (-70.99827 44.06743) \n", "99998 POINT (-93.61261 46.74485) \n", "100000 POINT (-87.95214 35.45926) \n", "\n", "[67968 rows x 7 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df = pd.read_csv(\n", " \"https://files.codeocean.com/files/verified/fa908bbc-11f9-4421-8bd3-72a4bf00427f_v2.0/data/int/applications/population/outcomes_sampled_population_CONTUS_16_640_UAR_100000_0.csv?download\", # noqa: E501\n", " index_col=0,\n", " na_values=[-999],\n", ").dropna()\n", "points = df[[\"lon\", \"lat\"]]\n", "population = df[\"population\"]\n", "\n", "gdf = geopandas.GeoDataFrame(df, geometry=geopandas.points_from_xy(df.lon, df.lat))\n", "gdf" ] }, { "cell_type": "markdown", "id": "cf320a80", "metadata": {}, "source": [ "Get rid of points with nodata population values" ] }, { "cell_type": "code", "execution_count": 5, "id": "eba76b98", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "population.plot.hist();" ] }, { "cell_type": "markdown", "id": "4120a78a", "metadata": {}, "source": [ "Population is lognormal distributed, so transforming it to log-space makes sense for modeling purposes" ] }, { "cell_type": "code", "execution_count": 6, "id": "a4877e7f", "metadata": {}, "outputs": [], "source": [ "population_log = np.log10(population + 1)" ] }, { "cell_type": "code", "execution_count": 7, "id": "c6b4f9d5", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "population_log.plot.hist();" ] }, { "cell_type": "code", "execution_count": 8, "id": "605bc1e7", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "ax = points.assign(population=population_log).plot.scatter(\n", " x=\"lon\",\n", " y=\"lat\",\n", " c=\"population\",\n", " s=1,\n", " cmap=\"viridis\",\n", " figsize=(10, 6),\n", " colorbar=False,\n", ")\n", "ax.set_axis_off();" ] }, { "cell_type": "markdown", "id": "83e52c4e", "metadata": {}, "source": [ "### Extract features from the imagery around each point\n", "\n", "We need to find a suitable Sentinel 2 scene for each point. As usual, we'll use `pystac-client` to search for items matching some conditions, but we don't just want do make a `.search()` call for each of the 67,968 remaining points. Each HTTP request is relatively slow. Instead, we will *batch* or points and search *in parallel*.\n", "\n", "We need to be a bit careful with how we batch up our points though. Since a single Sentinel 2 scene will cover many points, we want to make sure that points which are spatially close together end up in the same batch. In short, we need to spatially partition the dataset. This is implemented in `dask-geopandas`.\n", "\n", "So the overall workflow will be\n", "\n", "1. Find an appropriate STAC item for each point (in parallel, using the spatially partitioned dataset)\n", "2. Feed the points and STAC items to a custom Dataset that can read imagery given a point and the URL of a overlapping S2 scene\n", "3. 
Use a custom DataLoader, which uses our Dataset, to feed our model imagery and save the corresponding features" ] }, { "cell_type": "code", "execution_count": 9, "id": "d9d7aefb", "metadata": {}, "outputs": [], "source": [ "NPARTITIONS = 250\n", "\n", "ddf = dask_geopandas.from_geopandas(gdf, npartitions=1)\n", "hd = ddf.hilbert_distance().compute()\n", "gdf[\"hd\"] = hd\n", "gdf = gdf.sort_values(\"hd\")\n", "\n", "dgdf = dask_geopandas.from_geopandas(gdf, npartitions=NPARTITIONS, sort=False)" ] }, { "cell_type": "markdown", "id": "3b6ae053", "metadata": {}, "source": [ "We'll write a helper function that, for each point in a partition, finds the least cloudy Sentinel-2 scene covering that point." ] }, { "cell_type": "code", "execution_count": 10, "id": "e55349b5", "metadata": {}, "outputs": [], "source": [ "def query(points):\n", " \"\"\"\n", " Find a STAC item for points in the `points` DataFrame\n", "\n", " Parameters\n", " ----------\n", " points : geopandas.GeoDataFrame\n", " A GeoDataFrame of points (one spatial partition)\n", "\n", " Returns\n", " -------\n", " geopandas.GeoDataFrame\n", " A new geopandas.GeoDataFrame with a `stac_item` column containing the least\n", " cloudy STAC item that covers each point, or None if no item matches.\n", " \"\"\"\n", " intersects = shapely.geometry.mapping(points.unary_union.convex_hull)\n", "\n", " # The time frame in which we search for non-cloudy imagery\n", " search_start = \"2018-01-01\"\n", " search_end = \"2019-12-31\"\n", " catalog = pystac_client.Client.open(\n", " \"https://planetarycomputer.microsoft.com/api/stac/v1\"\n", " )\n", "\n", " search = catalog.search(\n", " collections=[\"sentinel-2-l2a\"],\n", " intersects=intersects,\n", " datetime=[search_start, search_end],\n", " query={\"eo:cloud_cover\": {\"lt\": 10}},\n", " limit=500,\n", " )\n", " ic = search.get_all_items_as_dict()\n", "\n", " features = ic[\"features\"]\n", " features_d = {item[\"id\"]: item for item in features}\n", "\n", " data = {\n", " \"eo:cloud_cover\": [],\n", " \"geometry\": [],\n", " }\n", "\n", " index = []\n", "\n", " for item in features:\n", " 
data[\"eo:cloud_cover\"].append(item[\"properties\"][\"eo:cloud_cover\"])\n", " data[\"geometry\"].append(shapely.geometry.shape(item[\"geometry\"]))\n", " index.append(item[\"id\"])\n", "\n", " items = geopandas.GeoDataFrame(data, index=index, geometry=\"geometry\").sort_values(\n", " \"eo:cloud_cover\"\n", " )\n", " point_list = points.geometry.tolist()\n", "\n", " point_items = []\n", " for point in point_list:\n", " covered_by = items[items.covers(point)]\n", " if len(covered_by):\n", " point_items.append(features_d[covered_by.index[0]])\n", " else:\n", " # There weren't any scenes matching our conditions for this point (too cloudy)\n", " point_items.append(None)\n", "\n", " return points.assign(stac_item=point_items)" ] }, { "cell_type": "code", "execution_count": 11, "id": "5350c952", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "2022-09-20 14:39:44,577 - distributed.diskutils - INFO - Found stale lock file and directory '/home/jovyan/PlanetaryComputerExamples/tutorials/dask-worker-space/worker-rz97cdcq', purging\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "/user/taugspurger@microsoft.com/proxy/8787/status\n", "CPU times: user 8.84 s, sys: 5.28 s, total: 14.1 s\n", "Wall time: 1min 27s\n" ] } ], "source": [ "%%time\n", "\n", "with Client(n_workers=16) as client:\n", " print(client.dashboard_link)\n", " meta = dgdf._meta.assign(stac_item=[])\n", " df2 = dgdf.map_partitions(query, meta=meta).compute()" ] }, { "cell_type": "code", "execution_count": 12, "id": "9d9ea985", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
\n", "
" ], "text/plain": [ " V1 V1.1 ID lon lat population \\\n", "58227 58227 58227 2253,2012 -97.303628 25.958163 0.193628 \n", "27993 27993 27993 2245,1999 -97.482670 26.057180 244.428020 \n", "61808 61808 61808 2246,1986 -97.661712 26.044807 525.789124 \n", "97710 97710 97710 2240,1992 -97.579077 26.119023 70.362312 \n", "55442 55442 55442 2231,1994 -97.551532 26.230258 44.150997 \n", "\n", " geometry hd \\\n", "58227 POINT (-97.30363 25.95816) 360793917 \n", "27993 POINT (-97.48267 26.05718) 360931126 \n", "61808 POINT (-97.66171 26.04481) 360950765 \n", "97710 POINT (-97.57908 26.11902) 360966149 \n", "55442 POINT (-97.55153 26.23026) 361043545 \n", "\n", " stac_item \n", "58227 {'id': 'S2A_MSIL2A_20191213T170711_R069_T14RPP... \n", "27993 {'id': 'S2A_MSIL2A_20191213T170711_R069_T14RPP... \n", "61808 {'id': 'S2A_MSIL2A_20191213T170711_R069_T14RPP... \n", "97710 {'id': 'S2A_MSIL2A_20191213T170711_R069_T14RPP... \n", "55442 {'id': 'S2B_MSIL2A_20181213T170709_R069_T14RPQ... " ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "df2.head()" ] }, { "cell_type": "code", "execution_count": 13, "id": "14eab13c", "metadata": {}, "outputs": [], "source": [ "df3 = df2.dropna(subset=[\"stac_item\"])\n", "\n", "matching_urls = [\n", " pc.sign(item[\"assets\"][\"visual\"][\"href\"]) for item in df3.stac_item.tolist()\n", "]\n", "\n", "points = df3[[\"lon\", \"lat\"]].to_numpy()\n", "population_log = np.log10(df3[\"population\"].to_numpy() + 1)" ] }, { "cell_type": "code", "execution_count": 14, "id": "26646f38", "metadata": {}, "outputs": [], "source": [ "class CustomDataset(Dataset):\n", " def __init__(self, points, fns, buffer=500):\n", " self.points = points\n", " self.fns = fns\n", " self.buffer = buffer\n", "\n", " def __len__(self):\n", " return self.points.shape[0]\n", "\n", " def __getitem__(self, idx):\n", "\n", " lon, lat = self.points[idx]\n", " fn = self.fns[idx]\n", "\n", " if fn is None:\n", " return None\n", " else:\n", " 
point_geom = shapely.geometry.mapping(shapely.geometry.Point(lon, lat))\n", "\n", " with rasterio.Env():\n", " with rasterio.open(fn, \"r\") as f:\n", " point_geom = rasterio.warp.transform_geom(\n", " \"epsg:4326\", f.crs.to_string(), point_geom\n", " )\n", " point_shape = shapely.geometry.shape(point_geom)\n", " mask_shape = point_shape.buffer(self.buffer).envelope\n", " mask_geom = shapely.geometry.mapping(mask_shape)\n", " try:\n", " out_image, out_transform = rasterio.mask.mask(\n", " f, [mask_geom], crop=True\n", " )\n", " except ValueError as e:\n", " if \"Input shapes do not overlap raster.\" in str(e):\n", " return None\n", " else:\n", " # Re-raise unexpected errors instead of continuing without an image\n", " raise\n", "\n", " out_image = out_image / 255.0\n", " out_image = torch.from_numpy(out_image).float()\n", " return out_image" ] }, { "cell_type": "code", "execution_count": 15, "id": "a696ef33", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/srv/conda/envs/notebook/lib/python3.9/site-packages/torch/utils/data/dataloader.py:563: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. 
Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.\n", " warnings.warn(_create_warning_msg(\n" ] } ], "source": [ "dataset = CustomDataset(points, matching_urls)\n", "\n", "dataloader = DataLoader(\n", " dataset,\n", " batch_size=8,\n", " shuffle=False,\n", " num_workers=os.cpu_count() * 2,\n", " collate_fn=lambda x: x,\n", " pin_memory=False,\n", ")" ] }, { "cell_type": "code", "execution_count": 16, "id": "b1e1e002", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "0/67922 -- 0.00% -- 2.14 seconds\n", "1000/67922 -- 1.47% -- 9.83 seconds\n", "2000/67922 -- 2.94% -- 10.77 seconds\n", "3000/67922 -- 4.42% -- 8.54 seconds\n", "4000/67922 -- 5.89% -- 9.70 seconds\n", "5000/67922 -- 7.36% -- 10.22 seconds\n", "6000/67922 -- 8.83% -- 9.24 seconds\n", "7000/67922 -- 10.31% -- 8.40 seconds\n", "8000/67922 -- 11.78% -- 10.03 seconds\n", "9000/67922 -- 13.25% -- 10.12 seconds\n", "10000/67922 -- 14.72% -- 7.72 seconds\n", "11000/67922 -- 16.20% -- 8.70 seconds\n", "12000/67922 -- 17.67% -- 9.44 seconds\n", "13000/67922 -- 19.14% -- 9.63 seconds\n", "14000/67922 -- 20.61% -- 8.08 seconds\n", "15000/67922 -- 22.08% -- 6.65 seconds\n", "16000/67922 -- 23.56% -- 7.16 seconds\n", "17000/67922 -- 25.03% -- 9.23 seconds\n", "18000/67922 -- 26.50% -- 8.26 seconds\n", "19000/67922 -- 27.97% -- 7.38 seconds\n", "20000/67922 -- 29.45% -- 9.18 seconds\n", "21000/67922 -- 30.92% -- 8.15 seconds\n", "22000/67922 -- 32.39% -- 6.03 seconds\n", "23000/67922 -- 33.86% -- 6.42 seconds\n", "24000/67922 -- 35.33% -- 7.21 seconds\n", "25000/67922 -- 36.81% -- 4.22 seconds\n", "26000/67922 -- 38.28% -- 5.08 seconds\n", "27000/67922 -- 39.75% -- 5.04 seconds\n", "28000/67922 -- 41.22% -- 6.26 seconds\n", "29000/67922 -- 42.70% -- 7.31 seconds\n", "30000/67922 -- 44.17% -- 9.81 seconds\n", "31000/67922 -- 45.64% -- 9.71 seconds\n", "32000/67922 
-- 47.11% -- 8.95 seconds\n", "33000/67922 -- 48.59% -- 8.53 seconds\n", "34000/67922 -- 50.06% -- 8.38 seconds\n", "35000/67922 -- 51.53% -- 9.55 seconds\n", "36000/67922 -- 53.00% -- 8.56 seconds\n", "37000/67922 -- 54.47% -- 7.19 seconds\n", "38000/67922 -- 55.95% -- 7.69 seconds\n", "39000/67922 -- 57.42% -- 6.36 seconds\n", "40000/67922 -- 58.89% -- 6.36 seconds\n", "41000/67922 -- 60.36% -- 6.68 seconds\n", "42000/67922 -- 61.84% -- 6.39 seconds\n", "43000/67922 -- 63.31% -- 3.95 seconds\n", "44000/67922 -- 64.78% -- 6.06 seconds\n", "45000/67922 -- 66.25% -- 5.87 seconds\n", "46000/67922 -- 67.72% -- 5.54 seconds\n", "47000/67922 -- 69.20% -- 9.00 seconds\n", "48000/67922 -- 70.67% -- 7.76 seconds\n", "49000/67922 -- 72.14% -- 7.21 seconds\n", "50000/67922 -- 73.61% -- 7.46 seconds\n", "51000/67922 -- 75.09% -- 6.69 seconds\n", "52000/67922 -- 76.56% -- 7.35 seconds\n", "53000/67922 -- 78.03% -- 7.03 seconds\n", "54000/67922 -- 79.50% -- 7.80 seconds\n", "55000/67922 -- 80.98% -- 6.92 seconds\n", "56000/67922 -- 82.45% -- 7.41 seconds\n", "57000/67922 -- 83.92% -- 6.12 seconds\n", "58000/67922 -- 85.39% -- 6.56 seconds\n", "59000/67922 -- 86.86% -- 6.69 seconds\n", "60000/67922 -- 88.34% -- 8.39 seconds\n", "61000/67922 -- 89.81% -- 10.02 seconds\n", "62000/67922 -- 91.28% -- 8.55 seconds\n", "63000/67922 -- 92.75% -- 8.21 seconds\n", "64000/67922 -- 94.23% -- 7.93 seconds\n", "65000/67922 -- 95.70% -- 6.39 seconds\n", "66000/67922 -- 97.17% -- 8.39 seconds\n", "67000/67922 -- 98.64% -- 8.02 seconds\n" ] } ], "source": [ "x_all = np.zeros((points.shape[0], num_features), dtype=float)\n", "\n", "tic = time.time()\n", "i = 0\n", "for images in dataloader:\n", " for image in images:\n", "\n", " if image is not None:\n", " # A full image should be ~101x101 pixels (i.e. ~1km^2 at a 10m/px spatial\n", " # resolution), however we can receive smaller images if an input point\n", " # happens to be at the edge of a S2 scene (a literal edge case). 
To deal\n", " # with these (edge) cases we crudely drop all images where the spatial\n", " # dimensions aren't both at least 20 pixels.\n", " if image.shape[1] >= 20 and image.shape[2] >= 20:\n", " image = image.to(device)\n", " with torch.no_grad():\n", " feats = model(image.unsqueeze(0)).cpu().numpy()\n", "\n", " x_all[i] = feats\n", " else:\n", " # this happens if the point is close to the edge of a scene\n", " # (one or both of the spatial dimensions of the image are very small)\n", " pass\n", " else:\n", " pass # this happens if we do not find an S2 scene for some point\n", "\n", " if i % 1000 == 0:\n", " print(\n", " f\"{i}/{points.shape[0]} -- {i / points.shape[0] * 100:0.2f}%\"\n", " + f\" -- {time.time()-tic:0.2f} seconds\"\n", " )\n", " tic = time.time()\n", " i += 1" ] }, { "cell_type": "markdown", "id": "5a8cce22", "metadata": { "tags": [] }, "source": [ "### Use the extracted features and given labels to model population density as a function of imagery\n", "\n", "We split the available data 80/20 into train/test. We use cross-validation to tune the regularization parameter of a ridge regression model, then apply the model to the test data and measure the R$^2$." ] }, { "cell_type": "code", "execution_count": 17, "id": "be27b418", "metadata": {}, "outputs": [], "source": [ "y_all = population_log.copy()" ] }, { "cell_type": "code", "execution_count": 18, "id": "099e3a88", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((67922, 1024), (67922,))" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_all.shape, y_all.shape" ] }, { "cell_type": "markdown", "id": "f86a3dd6", "metadata": {}, "source": [ "One final masking step: any sample whose features are all zero is one for which we failed to extract features, so we drop it." 
] }, { "cell_type": "code", "execution_count": 19, "id": "5a28fad0", "metadata": {}, "outputs": [], "source": [ "nofeature_mask = ~(x_all.sum(axis=1) == 0)" ] }, { "cell_type": "code", "execution_count": 20, "id": "84ee8ae2", "metadata": {}, "outputs": [], "source": [ "x_all = x_all[nofeature_mask]\n", "y_all = y_all[nofeature_mask]" ] }, { "cell_type": "code", "execution_count": 21, "id": "e8520e69", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((67922, 1024), (67922,))" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "x_all.shape, y_all.shape" ] }, { "cell_type": "code", "execution_count": 22, "id": "0580dff8", "metadata": {}, "outputs": [], "source": [ "x_train, x_test, y_train, y_test = train_test_split(\n", " x_all, y_all, test_size=0.2, random_state=0\n", ")" ] }, { "cell_type": "code", "execution_count": 23, "id": "5f3f682c", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
RidgeCV(alphas=array([1.e-08, 1.e-07, 1.e-06, 1.e-05, 1.e-04, 1.e-03, 1.e-02, 1.e-01,\n",
       "       1.e+00, 1.e+01, 1.e+02, 1.e+03, 1.e+04, 1.e+05, 1.e+06, 1.e+07,\n",
       "       1.e+08]),\n",
       "        cv=5)
" ], "text/plain": [ "RidgeCV(alphas=array([1.e-08, 1.e-07, 1.e-06, 1.e-05, 1.e-04, 1.e-03, 1.e-02, 1.e-01,\n", " 1.e+00, 1.e+01, 1.e+02, 1.e+03, 1.e+04, 1.e+05, 1.e+06, 1.e+07,\n", " 1.e+08]),\n", " cv=5)" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ridge_cv_random = RidgeCV(cv=5, alphas=np.logspace(-8, 8, base=10, num=17))\n", "ridge_cv_random.fit(x_train, y_train)" ] }, { "cell_type": "code", "execution_count": 24, "id": "3f417356", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Validation R2 performance 0.55\n" ] } ], "source": [ "print(f\"Validation R2 performance {ridge_cv_random.best_score_:0.2f}\")" ] }, { "cell_type": "code", "execution_count": 25, "id": "6f8acfa6", "metadata": { "tags": [] }, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "y_pred = np.maximum(ridge_cv_random.predict(x_test), 0)\n", "\n", "plt.figure()\n", "plt.scatter(y_pred, y_test, alpha=0.2, s=4)\n", "plt.xlabel(\"Predicted\", fontsize=15)\n", "plt.ylabel(\"Ground Truth\", fontsize=15)\n", "plt.title(r\"$\\log_{10}(1 + $people$/$km$^2)$\", fontsize=15)\n", "plt.xticks(fontsize=14)\n", "plt.yticks(fontsize=14)\n", "\n", "plt.xlim([0, 6])\n", "plt.ylim([0, 6])\n", "\n", "plt.text(\n", " 0.5,\n", " 5,\n", " s=\"R$^2$ = %0.2f\" % (r2_score(y_test, y_pred)),\n", " fontsize=15,\n", " fontweight=\"bold\",\n", ")\n", "m, b = np.polyfit(y_pred, y_test, 1)\n", "plt.plot(y_pred, m * y_pred + b, color=\"black\")\n", "plt.gca().spines.right.set_visible(False)\n", "plt.gca().spines.top.set_visible(False)\n", "\n", "plt.show()\n", "plt.close()" ] }, { "cell_type": "markdown", "id": "fb17f81e", "metadata": {}, "source": [ "In addition to a R$^2$ value of ~0.55 on the test points, we can see that we have a rank-order correlation (spearman's r) of 0.66." ] }, { "cell_type": "code", "execution_count": 26, "id": "23655d99", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SpearmanrResult(correlation=0.6683500144093596, pvalue=0.0)" ] }, "execution_count": 26, "metadata": {}, "output_type": "execute_result" } ], "source": [ "spearmanr(y_pred, y_test)" ] }, { "cell_type": "markdown", "id": "e495f7b6", "metadata": {}, "source": [ "### Spatial extrapolation\n", "\n", "In the previous section we split the points randomly and found that our model can _interpolate_ population density with an R2 of 0.55, however this result does not say anything about how well the model will extrapolate. Whenever you are modeling spatio-temporal data it is important to consider what the model is doing as well as the purpose of the model, then evaluate it appropriately. Here, we test how our modeling approach above is able to extrapolate to areas that it has not been trained on. 
Specifically we train the linear model with data from the _western_ portion of the US then test it on data from the _eastern_ US and interpret the results." ] }, { "cell_type": "code", "execution_count": 27, "id": "5a03a43b", "metadata": {}, "outputs": [], "source": [ "points = points[nofeature_mask]" ] }, { "cell_type": "markdown", "id": "ef3c8a6e", "metadata": {}, "source": [ "First we calculate the 80th percentile longitude of the points in our dataset. Points that are to the west of this value will be in our training split and points to the east of this will be in our testing split." ] }, { "cell_type": "code", "execution_count": 28, "id": "78ee7a41", "metadata": {}, "outputs": [], "source": [ "split_lon = np.percentile(points[:, 0], 80)\n", "train_idxs = np.where(points[:, 0] <= split_lon)[0]\n", "test_idxs = np.where(points[:, 0] > split_lon)[0]\n", "\n", "x_train = x_all[train_idxs]\n", "x_test = x_all[test_idxs]\n", "y_train = y_all[train_idxs]\n", "y_test = y_all[test_idxs]" ] }, { "cell_type": "markdown", "id": "01ac0a57", "metadata": {}, "source": [ "Visually, the split looks like this:" ] }, { "cell_type": "code", "execution_count": 29, "id": "97e4da02", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "plt.figure()\n", "plt.scatter(points[:, 0], points[:, 1], c=y_all, s=1)\n", "plt.vlines(\n", " split_lon,\n", " ymin=points[:, 1].min(),\n", " ymax=points[:, 1].max(),\n", " color=\"black\",\n", " linewidth=4,\n", ")\n", "plt.axis(\"off\")\n", "plt.show()\n", "plt.close()" ] }, { "cell_type": "code", "execution_count": 30, "id": "e1cd8fb2", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "RidgeCV(alphas=array([1.e-08, 1.e-07, 1.e-06, 1.e-05, 1.e-04, 1.e-03, 1.e-02, 1.e-01,\n", " 1.e+00, 1.e+01, 1.e+02, 1.e+03, 1.e+04, 1.e+05, 1.e+06, 1.e+07,\n", " 1.e+08]),\n", " cv=5)" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "ridge_cv = RidgeCV(cv=5, alphas=np.logspace(-8, 8, base=10, num=17))\n", "ridge_cv.fit(x_train, y_train)" ] }, { "cell_type": "markdown", "id": "57cd5663", "metadata": {}, "source": [ "We can see that our validation performance is similar to that of the random split:" ] }, { "cell_type": "code", "execution_count": 31, "id": "158f75c0", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Validation R2 performance 0.14\n" ] } ], "source": [ "print(f\"Validation R2 performance {ridge_cv.best_score_:0.2f}\")" ] }, { "cell_type": "markdown", "id": "cdbe52f0", "metadata": {}, "source": [ "However, our _test_ R$^2$ is much lower, 0.13 compared to 0.55. This shows that the linear model trained on **MOSAIKS** features and population data sampled from the _western_ US is not able to predict the population density in the _eastern_ US as well. However, from the scatter plot we can see that the predictions aren't random which warrants further investigation..." ] }, { "cell_type": "code", "execution_count": 32, "id": "b1fe36e4", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "y_pred = np.maximum(ridge_cv.predict(x_test), 0)\n", "\n", "plt.figure()\n", "plt.scatter(y_pred, y_test, alpha=0.2, s=4)\n", "plt.xlabel(\"Predicted\", fontsize=15)\n", "plt.ylabel(\"Ground Truth\", fontsize=15)\n", "plt.title(r\"$\\log_{10}(1 + $people$/$km$^2)$\", fontsize=15)\n", "plt.xticks(fontsize=14)\n", "plt.yticks(fontsize=14)\n", "\n", "plt.xlim([0, 6])\n", "plt.ylim([0, 6])\n", "\n", "plt.text(\n", " 0.5,\n", " 5,\n", " s=\"R$^2$ = %0.2f\" % (r2_score(y_test, y_pred)),\n", " fontsize=15,\n", " fontweight=\"bold\",\n", ")\n", "m, b = np.polyfit(y_pred, y_test, 1)\n", "plt.plot(y_pred, m * y_pred + b, color=\"black\")\n", "plt.gca().spines.right.set_visible(False)\n", "plt.gca().spines.top.set_visible(False)\n", "\n", "plt.show()\n", "plt.close()" ] }, { "cell_type": "markdown", "id": "430151e7", "metadata": {}, "source": [ "The rank-order correlation is still high, 0.61 compared to 0.66 from the random split. This shows that the model is still able to correctly _order_ the density of input points, however is wrong about the magnitude of the population densities." ] }, { "cell_type": "code", "execution_count": 33, "id": "3dc6aa83", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "SpearmanrResult(correlation=0.615843938172293, pvalue=0.0)" ] }, "execution_count": 33, "metadata": {}, "output_type": "execute_result" } ], "source": [ "spearmanr(y_test, y_pred)" ] }, { "cell_type": "markdown", "id": "5e0e2e0b", "metadata": {}, "source": [ "This makes sense when we compare the distributions of population density of points from the western US to that of the eastern US -- the label distributions are completely different. The distribution of **MOSAIKS** features likely doesn't change, however the mapping between these features and population density definitely varies with space." 
] }, { "cell_type": "code", "execution_count": 34, "id": "0bc15f54", "metadata": {}, "outputs": [ { "data": { "text/html": [ "" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "text/html": [ "" ], "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "bins = np.linspace(0, 5, num=50)\n", "\n", "plt.figure()\n", "plt.hist(y_train, bins=bins)\n", "plt.ylabel(\"Frequency\")\n", "plt.xlabel(r\"$\\log_{10}(1 + $people$/$km$^2)$\")\n", "plt.title(\"Train points -- western US\")\n", "plt.gca().spines.right.set_visible(False)\n", "plt.gca().spines.top.set_visible(False)\n", "plt.show()\n", "plt.close()\n", "\n", "plt.figure()\n", "plt.hist(y_test, bins=bins)\n", "plt.ylabel(\"Frequency\")\n", "plt.xlabel(r\"$\\log_{10}(1 + $people$/$km$^2)$\")\n", "plt.title(\"Test points -- eastern US\")\n", "plt.gca().spines.right.set_visible(False)\n", "plt.gca().spines.top.set_visible(False)\n", "plt.show()\n", "plt.close()" ] }, { "cell_type": "markdown", "id": "67397753", "metadata": {}, "source": [ "Estimating the _relative_ population density of points from the corresponding imagery is still useful in a wide variety of applications, e.g. in disaster response settings it might make sense to allocate the most resources to the most densely populated locations, where the precise density isn't as important." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "widgets": { "application/vnd.jupyter.widget-state+json": { "state": {}, "version_major": 2, "version_minor": 0 } } }, "nbformat": 4, "nbformat_minor": 5 }