{ "cells": [ { "cell_type": "markdown", "id": "8dd35554-a9dd-49cf-b9fa-24fa8ae6cecf", "metadata": {}, "source": [ "# Burn scar analysis using embeddings from partial inputs\n", "This notebook contains a complete example of how to run Clay. It\n", "combines the following three aspects:\n", "\n", "1. Create single-chip datacubes with time series data for a location and a date range\n", "2. Run the model with partial inputs, in this case RGB + NIR\n", "3. Study burn scars through the embeddings generated for that datacube\n", "\n", "## Let's start by importing packages and creating constants" ] }, { "cell_type": "code", "execution_count": 1, "id": "b7bcff1e-bdb5-47f8-aa0e-d68d6fdd3476", "metadata": {}, "outputs": [], "source": [ "# Ensure working directory is the repo home\n", "import os\n", "\n", "os.chdir(\"..\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "15d65ec9-86aa-4275-89ba-ec79fdbad361", "metadata": {}, "outputs": [], "source": [ "import warnings\n", "from pathlib import Path\n", "\n", "import geopandas as gpd\n", "import matplotlib.pyplot as plt\n", "import numpy\n", "import pandas as pd\n", "import pystac_client\n", "import rasterio\n", "import rioxarray  # noqa: F401\n", "import stackstac\n", "import torch\n", "from rasterio.enums import Resampling\n", "from shapely import Point\n", "from sklearn import decomposition\n", "\n", "from src.datamodule import ClayDataModule\n", "from src.model_clay import CLAYModule\n", "\n", "warnings.filterwarnings(\"ignore\")\n", "\n", "BAND_GROUPS = {\n", "    \"rgb\": [\"red\", \"green\", \"blue\"],\n", "    \"rededge\": [\"rededge1\", \"rededge2\", \"rededge3\", \"nir08\"],\n", "    \"nir\": [\n", "        \"nir\",\n", "    ],\n", "    \"swir\": [\"swir16\", \"swir22\"],\n", "    \"sar\": [\"vv\", \"vh\"],\n", "}\n", "\n", "STAC_API = \"https://earth-search.aws.element84.com/v1\"\n", "COLLECTION = \"sentinel-2-l2a\"" ] }, { "cell_type": "markdown", "id": "a6341305-9c44-4a1e-847c-80d77b01c0bf", "metadata": {}, "source": [ "## 
Search for imagery over an area of interest\n", "In this example we use a location and date range to visualize a forest fire that happened in [Monchique in 2018](https://pt.wikipedia.org/wiki/Inc%C3%AAndio_de_Monchique_de_2018)." ] }, { "cell_type": "code", "execution_count": 3, "id": "a1886f5a-8669-40e7-8fae-e45619570e3c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 12 items\n" ] } ], "source": [ "# Point over Monchique, Portugal\n", "poi = 37.30939, -8.57207\n", "\n", "# Date range covering a large forest fire\n", "start = \"2018-07-01\"\n", "end = \"2018-09-01\"\n", "\n", "catalog = pystac_client.Client.open(STAC_API)\n", "\n", "search = catalog.search(\n", "    collections=[COLLECTION],\n", "    datetime=f\"{start}/{end}\",\n", "    bbox=(poi[1] - 1e-5, poi[0] - 1e-5, poi[1] + 1e-5, poi[0] + 1e-5),\n", "    max_items=100,\n", "    query={\"eo:cloud_cover\": {\"lt\": 80}},\n", ")\n", "\n", "items = search.get_all_items()\n", "\n", "print(f\"Found {len(items)} items\")" ] }, { "cell_type": "markdown", "id": "c4ba5c36-90a6-427c-80c5-2a83ad11a1b0", "metadata": {}, "source": [ "## Download the data\n", "Get the data into a NumPy array and visualize the imagery. The burn scar is visible in the last five images." 
] }, { "cell_type": "code", "execution_count": null, "id": "c371501c-3ef0-4507-9073-0521a1c733be", "metadata": {}, "outputs": [], "source": [ "# Extract coordinate system from first item\n", "epsg = items[0].properties[\"proj:epsg\"]\n", "\n", "# Convert point into the image projection\n", "poidf = gpd.GeoDataFrame(\n", "    pd.DataFrame(),\n", "    crs=\"EPSG:4326\",\n", "    geometry=[Point(poi[1], poi[0])],\n", ").to_crs(epsg)\n", "\n", "coords = poidf.iloc[0].geometry.coords[0]\n", "\n", "# Create bounds of the correct size. The model requires 512x512 pixels\n", "# at 10 m resolution, i.e. 2560 m on each side of the point.\n", "bounds = (\n", "    coords[0] - 2560,\n", "    coords[1] - 2560,\n", "    coords[0] + 2560,\n", "    coords[1] + 2560,\n", ")\n", "\n", "# Retrieve the pixel values, for the bounding box in\n", "# the target projection. In this example we use only\n", "# the RGB and NIR band groups.\n", "stack = stackstac.stack(\n", "    items,\n", "    bounds=bounds,\n", "    snap_bounds=False,\n", "    epsg=epsg,\n", "    resolution=10,\n", "    dtype=\"float32\",\n", "    rescale=False,\n", "    fill_value=0,\n", "    assets=BAND_GROUPS[\"rgb\"] + BAND_GROUPS[\"nir\"],\n", "    resampling=Resampling.nearest,\n", ")\n", "\n", "stack = stack.compute()\n", "\n", "stack.sel(band=[\"red\", \"green\", \"blue\"]).plot.imshow(\n", "    row=\"time\", rgb=\"band\", vmin=0, vmax=2000, col_wrap=6\n", ")" ] }, { "cell_type": "markdown", "id": "ce633fb1-fc82-4c88-8204-cda47aa9c874", "metadata": {}, "source": [ "![Minicube visualization](https://github.com/Clay-foundation/model/assets/901647/c6e924e5-6ba1-4924-b99a-df8b90731a5f)" ] }, { "cell_type": "markdown", "id": "77e7c22c-1bfd-4281-bb12-8330c3eedc25", "metadata": {}, "source": [ "## Write data to tif files\n", "To use the mini datacube in the Clay data loader, we need to write the\n", "images to tif files on disk. These tif files are then used by the Clay\n", "data loader for creating embeddings below." 
] }, { "cell_type": "code", "execution_count": 5, "id": "6509c3b2-a67c-447d-a7a1-e5fbcc1e35b5", "metadata": {}, "outputs": [], "source": [ "outdir = Path(\"data/minicubes\")\n", "outdir.mkdir(exist_ok=True, parents=True)\n", "\n", "# Write each tile to the output dir\n", "for tile in stack:\n", "    # Grid code like MGRS-29SNB\n", "    mgrs = str(tile.coords[\"grid:code\"].values).split(\"-\")[1]\n", "    date = str(tile.time.values)[:10]\n", "\n", "    name = \"{dir}/claytile_{mgrs}_{date}.tif\".format(\n", "        dir=outdir,\n", "        mgrs=mgrs,\n", "        date=date.replace(\"-\", \"\"),\n", "    )\n", "    tile.rio.to_raster(name, compress=\"deflate\")\n", "\n", "    # Store the acquisition date as a tag in the tif file\n", "    with rasterio.open(name, \"r+\") as rst:\n", "        rst.update_tags(date=date)" ] }, { "cell_type": "markdown", "id": "ebc4b6ee-db58-4005-9689-a7d0acdc6a79", "metadata": { "scrolled": true }, "source": [ "## Create embeddings\n", "Now switch gears and load the tiles to create embeddings and analyze them.\n", "\n", "The model checkpoint can be loaded directly from Hugging Face, and the data\n", "directory points to the directory we created in the steps above.\n", "\n", "Note that the normalization parameters for the data module need to be\n", "adapted based on the band groups that were selected as partial input. The\n", "full set of normalization parameters can be found [here](https://github.com/Clay-foundation/model/blob/main/src/datamodule.py#L108)." 
] }, { "cell_type": "markdown", "id": "d89e0135-9473-4f76-9f09-e4e295dd51c9", "metadata": {}, "source": [ "### Load the model and set up the data module" ] }, { "cell_type": "code", "execution_count": 6, "id": "301ee2db-c5fc-4628-b837-12e6ea477415", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of chips: 12\n" ] } ], "source": [ "DATA_DIR = \"data/minicubes\"\n", "CKPT_PATH = \"https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt\"\n", "\n", "# Load model\n", "rgb_model = CLAYModule.load_from_checkpoint(\n", "    CKPT_PATH,\n", "    mask_ratio=0.0,\n", "    band_groups={\"rgb\": (2, 1, 0), \"nir\": (3,)},\n", "    bands=4,\n", "    strict=False,  # ignore the extra parameters in the checkpoint\n", ")\n", "# Set the model to evaluation mode\n", "rgb_model.eval()\n", "\n", "\n", "# Load the datamodule, with the reduced set of normalization\n", "# parameters for the selected bands\n", "class ClayDataModuleRGB(ClayDataModule):\n", "    MEAN = [\n", "        1369.03,  # red\n", "        1597.68,  # green\n", "        1741.10,  # blue\n", "        2858.43,  # nir\n", "    ]\n", "    STD = [\n", "        2026.96,  # red\n", "        2011.88,  # green\n", "        2146.35,  # blue\n", "        2016.38,  # nir\n", "    ]\n", "\n", "\n", "data_dir = Path(DATA_DIR)\n", "\n", "dm = ClayDataModuleRGB(data_dir=str(data_dir.absolute()), batch_size=20)\n", "dm.setup(stage=\"predict\")\n", "trn_dl = iter(dm.predict_dataloader())" ] }, { "cell_type": "markdown", "id": "db3f3e5e-8668-4830-9c77-cc1d8cb35234", "metadata": {}, "source": [ "### Create the embeddings for the images over the forest fire\n", "This will loop through the images returned by the data loader\n", "and evaluate the model for each one of the images. The raw\n", "embeddings are reduced to mean values to simplify the data." 
] }, { "cell_type": "code", "execution_count": 7, "id": "c5762240-9d22-4ebd-8e39-83fc6594a459", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Average embeddings have shape (12, 768)\n" ] } ], "source": [ "embeddings = []\n", "\n", "for batch in trn_dl:\n", "    with torch.inference_mode():\n", "        # Move the input data to the device of the model\n", "        batch[\"pixels\"] = batch[\"pixels\"].to(rgb_model.device)\n", "        batch[\"timestep\"] = batch[\"timestep\"].to(rgb_model.device)\n", "        batch[\"latlon\"] = batch[\"latlon\"].to(rgb_model.device)\n", "\n", "        # Pass pixels, latlon, timestep through the encoder to create encoded patches\n", "        (\n", "            unmasked_patches,\n", "            unmasked_indices,\n", "            masked_indices,\n", "            masked_matrix,\n", "        ) = rgb_model.model.encoder(batch)\n", "\n", "        embeddings.append(unmasked_patches.detach().cpu().numpy())\n", "\n", "embeddings = numpy.vstack(embeddings)\n", "\n", "embeddings_mean = embeddings[:, :-2, :].mean(axis=1)\n", "\n", "print(f\"Average embeddings have shape {embeddings_mean.shape}\")" ] }, { "cell_type": "markdown", "id": "72db5745-21c6-4b8e-b8f7-cb48c0f9c9ef", "metadata": {}, "source": [ "## Analyze embeddings\n", "Now we can make a simple analysis of the embeddings. We reduce each\n", "embedding to a single number using Principal Component Analysis. Then\n", "we can plot the first principal component over time. The effect of the fire on the\n", "embeddings is clearly visible. We use the following color code in the graph:\n", "\n", "| Color | Interpretation |\n", "|---|---|\n", "| Green | Cloudy images |\n", "| Blue | Before the fire |\n", "| Red | After the fire |" ] }, { "cell_type": "code", "execution_count": 8, "id": "88f3b2dc-8f2a-447b-a6af-b04e0d1ff61c", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "pca = decomposition.PCA(n_components=1)\n", "pca_result = pca.fit_transform(embeddings_mean)\n", "\n", "plt.xticks(rotation=-30)\n", "# All points\n", "plt.scatter(stack.time, pca_result, color=\"blue\")\n", "\n", "# Cloudy images\n", "plt.scatter(stack.time[0], pca_result[0], color=\"green\")\n", "plt.scatter(stack.time[2], pca_result[2], color=\"green\")\n", "\n", "# After fire\n", "plt.scatter(stack.time[-5:], pca_result[-5:], color=\"red\")" ] }, { "cell_type": "markdown", "id": "a16fbdb8-1c2d-4c84-8526-283fa14faa53", "metadata": {}, "source": [ "In the plot above, each image embedding is one point. One can clearly \n", "distinguish the two cloudy images and the values after the fire are\n", "consistently low." ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.6" } }, "nbformat": 4, "nbformat_minor": 5 }