{ "cells": [ { "cell_type": "markdown", "id": "8dd35554-a9dd-49cf-b9fa-24fa8ae6cecf", "metadata": {}, "source": [ "# Monsoon flood analysis using embeddings from partial inputs\n", "This notebook contains a complete example for how to run Clay. It\n", "combines the following three different aspects:\n", "\n", "1. Create single-chip datacubes with time series data for a location and a date range\n", "2. Run the model with partial inputs, in this case RGB + NIR + SWIR\n", "3. Study flood extent through the embeddings generated for that datacube\n", "\n", "## Let's start with importing and creating constants" ] }, { "cell_type": "code", "execution_count": null, "id": "b7bcff1e-bdb5-47f8-aa0e-d68d6fdd3476", "metadata": {}, "outputs": [], "source": [ "# Ensure working directory is the repo home\n", "import os\n", "\n", "os.chdir(\"..\")" ] }, { "cell_type": "code", "execution_count": null, "id": "15d65ec9-86aa-4275-89ba-ec79fdbad361", "metadata": {}, "outputs": [], "source": [ "import warnings\n", "from pathlib import Path\n", "\n", "import geopandas as gpd\n", "import matplotlib.pyplot as plt\n", "import numpy\n", "import pandas as pd\n", "import pystac_client\n", "import rasterio\n", "import rioxarray # noqa: F401\n", "import stackstac\n", "import torch\n", "from rasterio.enums import Resampling\n", "from shapely import Point\n", "from sklearn import decomposition\n", "\n", "from src.datamodule import ClayDataModule\n", "from src.model_clay import CLAYModule\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "BAND_GROUPS = {\n", " \"rgb\": [\"red\", \"green\", \"blue\"],\n", " \"rededge\": [\"rededge1\", \"rededge2\", \"rededge3\", \"nir08\"],\n", " \"nir\": [\n", " \"nir\",\n", " ],\n", " \"swir\": [\"swir16\", \"swir22\"],\n", " \"sar\": [\"vv\", \"vh\"],\n", "}\n", "\n", "STAC_API = \"https://earth-search.aws.element84.com/v1\"\n", "COLLECTION = \"sentinel-2-l2a\"" ] }, { "cell_type": "markdown", "id": "a6341305-9c44-4a1e-847c-80d77b01c0bf", "metadata": {}, 
"source": [ "## Search for imagery over an area of interest\n", "In this example we use a location and date range to visualize a major monsoon flood that happened in [Padidan, Pakistan in 2022](https://floodlist.com/asia/pakistan-monsoon-floods-august-2022)." ] }, { "cell_type": "code", "execution_count": null, "id": "a1886f5a-8669-40e7-8fae-e45619570e3c", "metadata": {}, "outputs": [], "source": [ "# Point over Padidan, Pakistan\n", "poi = 26.776567, 68.287374\n", "\n", "# Dates of a major monsoon flood (August 20, 2022)\n", "start = \"2022-06-01\"\n", "end = \"2022-09-30\"\n", "\n", "catalog = pystac_client.Client.open(STAC_API)\n", "\n", "search = catalog.search(\n", " collections=[COLLECTION],\n", " datetime=f\"{start}/{end}\",\n", " bbox=(poi[1] - 1e-5, poi[0] - 1e-5, poi[1] + 1e-5, poi[0] + 1e-5),\n", " max_items=100,\n", " query={\"eo:cloud_cover\": {\"lt\": 50}},\n", ")\n", "\n", "items = search.get_all_items()\n", "\n", "print(f\"Found {len(items)} items\")" ] }, { "cell_type": "markdown", "id": "c4ba5c36-90a6-427c-80c5-2a83ad11a1b0", "metadata": {}, "source": [ "## Download the data\n", "Get the data into a numpy array and visualize the imagery. The flood is visible in the last seven images. Note: SWIR is very helpful for flood mapping (ref: [Satellite flood detection integrating hydrogeomorphic and spectral indices](https://www.tandfonline.com/doi/full/10.1080/15481603.2022.2143670#:~:text=Methods%20used%20to%20detect%20water,Near%20Infrared%20(NIR)%20wavelengths.))." 
] }, { "cell_type": "code", "execution_count": null, "id": "c371501c-3ef0-4507-9073-0521a1c733be", "metadata": {}, "outputs": [], "source": [ "# Extract coordinate system from first item\n", "epsg = items[0].properties[\"proj:epsg\"]\n", "\n", "# Convert point into the image projection\n", "poidf = gpd.GeoDataFrame(\n", "    pd.DataFrame(),\n", "    crs=\"EPSG:4326\",\n", "    geometry=[Point(poi[1], poi[0])],\n", ").to_crs(epsg)\n", "\n", "coords = poidf.iloc[0].geometry.coords[0]\n", "\n", "# Create bounds of the correct size: the model requires\n", "# 512x512 pixels at 10m resolution (5120m x 5120m).\n", "bounds = (\n", "    coords[0] - 2560,\n", "    coords[1] - 2560,\n", "    coords[0] + 2560,\n", "    coords[1] + 2560,\n", ")\n", "\n", "# Retrieve the pixel values for the bounding box in\n", "# the target projection. In this example we use the\n", "# RGB, NIR and SWIR band groups.\n", "stack = stackstac.stack(\n", "    items,\n", "    bounds=bounds,\n", "    snap_bounds=False,\n", "    epsg=epsg,\n", "    resolution=10,\n", "    dtype=\"float32\",\n", "    rescale=False,\n", "    fill_value=0,\n", "    assets=BAND_GROUPS[\"rgb\"] + BAND_GROUPS[\"nir\"] + BAND_GROUPS[\"swir\"],\n", "    resampling=Resampling.nearest,\n", ")\n", "\n", "stack = stack.compute()\n", "\n", "stack.sel(band=[\"red\", \"green\", \"blue\"]).plot.imshow(\n", "    row=\"time\", rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n", ")" ] }, { "cell_type": "markdown", "id": "daebe4a4-1f1d-4c20-be93-f29cb16f3fd7", "metadata": {}, "source": [ "### Plot the near infrared\n", "Notice the significant signal starting on August 26th (the first image after the major flood). It appears more strongly on the next date because the August 26th image was fairly cloudy." 
] }, { "cell_type": "code", "execution_count": null, "id": "05a94e7e-1046-447a-af8d-a33dbdb71686", "metadata": {}, "outputs": [], "source": [ "stack.sel(band=[\"nir\", \"nir\", \"nir\"]).plot.imshow(\n", " row=\"time\", rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n", ")" ] }, { "cell_type": "markdown", "id": "c4af92a3-ff35-4857-a3c1-7affe5717c98", "metadata": {}, "source": [ "### Plot the first short-wave infrared band\n", "Notice the same significant signal starting on August 26th (first image after the major flood), again appearing more strongly on the next date as August 26th was fairly cloudy." ] }, { "cell_type": "code", "execution_count": null, "id": "c84c9e0a-842d-41a8-bf76-665734b27e4e", "metadata": {}, "outputs": [], "source": [ "stack.sel(band=[\"swir16\", \"swir16\", \"swir16\"]).plot.imshow(\n", " row=\"time\", rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n", ")" ] }, { "cell_type": "markdown", "id": "40a687f7-1558-4261-925a-74b3c524e308", "metadata": {}, "source": [ "### Plot the second short-wave infrared band\n", "Notice the same significant signal starting on August 26th (first image after the major flood), again appearing more strongly on the next date as August 26th was fairly cloudy." ] }, { "cell_type": "code", "execution_count": null, "id": "ecb08fd4-9c84-43cf-bdf7-362d142b614a", "metadata": {}, "outputs": [], "source": [ "stack.sel(band=[\"swir22\", \"swir22\", \"swir22\"]).plot.imshow(\n", " row=\"time\", rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n", ")" ] }, { "cell_type": "markdown", "id": "77e7c22c-1bfd-4281-bb12-8330c3eedc25", "metadata": {}, "source": [ "## Write data to tif files\n", "To use the mini datacube in the Clay dataloader, we need to write the\n", "images to tif files on disk. These tif files are then used by the Clay\n", "data loader for creating embeddings below." 
] }, { "cell_type": "code", "execution_count": null, "id": "6509c3b2-a67c-447d-a7a1-e5fbcc1e35b5", "metadata": {}, "outputs": [], "source": [ "outdir = Path(\"data/minicubes\")\n", "outdir.mkdir(exist_ok=True, parents=True)\n", "\n", "# Write each tile (one per date) to the output dir\n", "for tile in stack:\n", "    # Grid code like MGRS-29SNB\n", "    mgrs = str(tile.coords[\"grid:code\"].values).split(\"-\")[1]\n", "    date = str(tile.time.values)[:10]\n", "\n", "    name = \"{dir}/claytile_{mgrs}_{date}.tif\".format(\n", "        dir=outdir,\n", "        mgrs=mgrs,\n", "        date=date.replace(\"-\", \"\"),\n", "    )\n", "    tile.rio.to_raster(name, compress=\"deflate\")\n", "\n", "    # Store the acquisition date as a tag in the tif file\n", "    with rasterio.open(name, \"r+\") as rst:\n", "        rst.update_tags(date=date)" ] }, { "cell_type": "markdown", "id": "ebc4b6ee-db58-4005-9689-a7d0acdc6a79", "metadata": { "scrolled": true }, "source": [ "## Create embeddings\n", "Now switch gears and load the tiles to create embeddings and analyze them.\n", "\n", "The model checkpoint can be loaded directly from Hugging Face, and the data\n", "directory points to the directory we created in the steps above.\n", "\n", "Note that the normalization parameters for the data module need to be\n", "adapted based on the band groups that were selected as partial input. The\n", "full set of normalization parameters can be found [here](https://github.com/Clay-foundation/model/blob/main/src/datamodule.py#L108)." 
] }, { "cell_type": "markdown", "id": "d89e0135-9473-4f76-9f09-e4e295dd51c9", "metadata": {}, "source": [ "### Load the model and set up the data module" ] }, { "cell_type": "code", "execution_count": null, "id": "301ee2db-c5fc-4628-b837-12e6ea477415", "metadata": {}, "outputs": [], "source": [ "DATA_DIR = \"data/minicubes\"\n", "CKPT_PATH = \"https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt\"\n", "\n", "# Load model\n", "multi_model = CLAYModule.load_from_checkpoint(\n", "    CKPT_PATH,\n", "    mask_ratio=0.0,\n", "    band_groups={\"rgb\": (2, 1, 0), \"nir\": (3,), \"swir\": (4, 5)},\n", "    bands=6,\n", "    strict=False,  # ignore the extra parameters in the checkpoint\n", ")\n", "# Set the model to evaluation mode\n", "multi_model.eval()\n", "\n", "\n", "# Define a datamodule with normalization parameters\n", "# for the reduced set of bands\n", "class ClayDataModuleMulti(ClayDataModule):\n", "    MEAN = [\n", "        1369.03,  # red\n", "        1597.68,  # green\n", "        1741.10,  # blue\n", "        2893.86,  # nir\n", "        2303.00,  # swir16\n", "        1807.79,  # swir22\n", "    ]\n", "    STD = [\n", "        2026.96,  # red\n", "        2011.88,  # green\n", "        2146.35,  # blue\n", "        1917.12,  # nir\n", "        1679.88,  # swir16\n", "        1568.06,  # swir22\n", "    ]\n", "\n", "\n", "data_dir = Path(DATA_DIR)\n", "\n", "dm = ClayDataModuleMulti(data_dir=str(data_dir.absolute()), batch_size=2)\n", "dm.setup(stage=\"predict\")\n", "trn_dl = iter(dm.predict_dataloader())" ] }, { "cell_type": "markdown", "id": "db3f3e5e-8668-4830-9c77-cc1d8cb35234", "metadata": {}, "source": [ "### Create the embeddings for the images over the flood event\n", "This will loop through the batches returned by the data loader\n", "and evaluate the model on each image. The raw patch embeddings\n", "are then averaged into one vector per image to simplify the data." 
] }, { "cell_type": "code", "execution_count": null, "id": "04da4fe7-4243-4016-ab1e-367aa20d20fd", "metadata": {}, "outputs": [], "source": [ "embeddings = []\n", "ts = []\n", "for batch in trn_dl:\n", "    with torch.no_grad():\n", "        # Move the input tensors to the device of the model\n", "        batch[\"pixels\"] = batch[\"pixels\"].to(multi_model.device)\n", "        batch[\"timestep\"] = batch[\"timestep\"].to(multi_model.device)\n", "        batch[\"latlon\"] = batch[\"latlon\"].to(multi_model.device)\n", "        # batch[\"date\"] is left as-is (no device transfer)\n", "\n", "        # Pass pixels, latlon, timestep through the encoder to create encoded patches\n", "        (\n", "            unmasked_patches,\n", "            unmasked_indices,\n", "            masked_indices,\n", "            masked_matrix,\n", "        ) = multi_model.model.encoder(batch)\n", "\n", "        embeddings.append(unmasked_patches.detach().cpu().numpy())\n", "        ts.append(batch[\"date\"])\n", "\n", "embeddings = numpy.vstack(embeddings)\n", "\n", "# Average over the patch dimension, leaving out the last two entries\n", "embeddings_mean = embeddings[:, :-2, :].mean(axis=1)" ] }, { "cell_type": "code", "execution_count": null, "id": "8042b3a0-831d-41fb-800b-c11c2e9068b5", "metadata": {}, "outputs": [], "source": [ "print(f\"Average embeddings have shape {embeddings_mean.shape}\")" ] }, { "cell_type": "markdown", "id": "79b3fae9", "metadata": {}, "source": [ "Check the dates. Notice they are in sublists of size 2 because of the batch size." 
] }, { "cell_type": "code", "execution_count": null, "id": "c327fec5-4d94-4e77-bfb3-fd3b732d35bd", "metadata": {}, "outputs": [], "source": [ "ts" ] }, { "cell_type": "markdown", "id": "8371aedc", "metadata": {}, "source": [ "Flatten the dates" ] }, { "cell_type": "code", "execution_count": null, "id": "d3b9be81-8083-4d1b-888f-1cc09cdd1003", "metadata": {}, "outputs": [], "source": [ "tss = [t for tb in ts for t in tb]" ] }, { "cell_type": "code", "execution_count": null, "id": "6aca5f7e-31ce-4326-b98d-48783379de4b", "metadata": {}, "outputs": [], "source": [ "tss" ] }, { "cell_type": "markdown", "id": "72db5745-21c6-4b8e-b8f7-cb48c0f9c9ef", "metadata": {}, "source": [ "## Analyze embeddings\n", "Now we can make a simple analysis of the embeddings. We reduce each\n", "embedding to a single number using Principal Component Analysis. Then\n", "we plot this first principal component against the image date. The effect\n", "of the flood on the embeddings is clearly visible. We use the following\n", "color code in the graph:\n", "\n", "| Color | Interpretation |\n", "|---|---|\n", "| Green | Cloudy Images |\n", "| Blue | Before the flood |\n", "| Red | After the flood |" ] }, { "cell_type": "code", "execution_count": null, "id": "e95d5ddb-daa9-4753-9470-d709b8235ae7", "metadata": {}, "outputs": [], "source": [ "pca = decomposition.PCA(n_components=1)\n", "pca_result = pca.fit_transform(embeddings_mean)\n", "\n", "plt.xticks(rotation=-30)\n", "# All points\n", "plt.scatter(tss, pca_result, color=\"blue\")\n", "# plt.scatter(stack.time, pca_result, color=\"blue\")\n", "\n", "# Cloudy images\n", "plt.scatter(tss[7], pca_result[7], color=\"green\")\n", "plt.scatter(tss[8], pca_result[8], color=\"green\")\n", "# plt.scatter(stack.time[7], pca_result[7], color=\"green\")\n", "# plt.scatter(stack.time[8], pca_result[8], color=\"green\")\n", "\n", "# After flood\n", "plt.scatter(tss[-7:], pca_result[-7:], color=\"red\")\n", "# plt.scatter(stack.time[-7:], pca_result[-7:], color=\"red\")" ] }, { 
"cell_type": "markdown", "id": "a16fbdb8-1c2d-4c84-8526-283fa14faa53", "metadata": {}, "source": [ "In the plot above, each image embedding is one point. One can clearly \n", "distinguish the two cloudy images and the values after the flood are\n", "consistently low." ] }, { "cell_type": "markdown", "id": "fd7a8eb3-8c7e-4e8b-87bc-618a6801ed8f", "metadata": {}, "source": [ "## t-SNE example\n", "A quick t-SNE calculation shows that the dates indeed cluster as we'd expect, with the before flood dates grouped together, and the after-flood days together." ] }, { "cell_type": "code", "execution_count": null, "id": "f32e7726-e468-4aea-858b-6d32fe64a83a", "metadata": {}, "outputs": [], "source": [ "from sklearn.manifold import TSNE\n", "\n", "# Perform t-SNE on the embeddings\n", "tsne = TSNE(n_components=2, perplexity=5)\n", "X_tsne = tsne.fit_transform(embeddings_mean)" ] }, { "cell_type": "code", "execution_count": null, "id": "24c4f8ff-642a-4236-bf33-4b88599cb922", "metadata": {}, "outputs": [], "source": [ "# Plot the results\n", "plt.figure(figsize=(10, 6))\n", "plt.scatter(X_tsne[:, 0], X_tsne[:, 1])\n", "\n", "# Annotate each point with the corresponding date\n", "for i, (x, y) in enumerate(zip(X_tsne[:, 0], X_tsne[:, 1])):\n", " plt.annotate(f\"{tss[i]}\", (x, y))\n", "\n", "plt.title(\"t-SNE Visualization\")\n", "plt.xlabel(\"t-SNE Component 1\")\n", "plt.ylabel(\"t-SNE Component 2\")\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "eeb19b34-353e-4184-b65a-0279022afc4d", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 5 }