{ "cells": [ { "cell_type": "markdown", "id": "287c93e9", "metadata": {}, "source": [ "# Patch level cloud coverage\n", "\n", "This notebook obtains patch level (32x32 subset) cloud cover percentages from the Scene classification mask tied to a Sentinel-2 dataset.\n", "\n", "We will demonstrate how to do the following:\n", "\n", "1. Leverage the [AWS Sentinel-2 STAC catalog](https://registry.opendata.aws/sentinel-2/) to obtain fine \"patch\" level cloud cover percentages from the [Scene classification (SCL) mask](https://docs.sentinel-hub.com/api/latest/data/sentinel-2-l2a/). These percentages will be mapped and added to the GeoParquet files such that they can be added to database tables for similarity search filters and any other relevant downstream tasks.\n", "2. Generate fine level (pixel of size 10m x 10m) embeddings for an area (5.12km x 5.12km).\n", "3. Save the fine level (patch) embeddings and execute a similarity search that leverages the cloud cover percentages as reference." ] }, { "cell_type": "code", "execution_count": null, "id": "9b0dbed2", "metadata": {}, "outputs": [], "source": [ "import glob\n", "from pathlib import Path\n", "\n", "import geopandas as gpd\n", "import lancedb\n", "import matplotlib.pyplot as plt\n", "import numpy\n", "import pandas as pd\n", "import pystac_client\n", "import rasterio\n", "import rioxarray # noqa: F401\n", "import shapely\n", "import stackstac\n", "import torch\n", "from rasterio.enums import Resampling\n", "from shapely.geometry import Polygon, box\n", "\n", "from src.datamodule import ClayDataModule\n", "from src.model_clay import CLAYModule\n", "\n", "pd.set_option(\"display.max_colwidth\", None)\n", "\n", "BAND_GROUPS_L2A = {\n", " \"rgb\": [\"red\", \"green\", \"blue\"],\n", " \"rededge\": [\"rededge1\", \"rededge2\", \"rededge3\", \"nir08\"],\n", " \"nir\": [\n", " \"nir\",\n", " ],\n", " \"swir\": [\"swir16\", \"swir22\"],\n", " \"sar\": [\"vv\", \"vh\"],\n", " \"scl\": [\"scl\"],\n", "}\n", "\n", "STAC_API_L2A = \"https://earth-search.aws.element84.com/v1\"\n", "COLLECTION_L2A = \"sentinel-2-l2a\"\n", "\n", "SCL_CLOUD_LABELS = [7, 8, 9, 10]" ] }, { "cell_type": "markdown", "id": "03fa77d3", "metadata": {}, "source": [ "## Find Sentinel-2 scenes stored as Cloud-Optimized GeoTIFFs" ] }, { "cell_type": "markdown", "id": "fda4d759-2c2b-430a-b4b8-21ebd0c4ccbb", "metadata": {}, "source": [ "#### Define an area of interest\n", "This is a hotspot area where mining extraction occurs on the island of Fiji. We used this in another tutorial, albeit with a cloud free composite. This will help demonstrate how we can capture clouds for the same region and time frame in the absence of a cloud-free composite." ] }, { "cell_type": "code", "execution_count": null, "id": "adb0b69a-faa2-4bfc-9cdd-d2294de9bec8", "metadata": {}, "outputs": [], "source": [ "# sample cluster\n", "bbox_bl = (177.4199, -17.8579)\n", "bbox_tl = (177.4156, -17.6812)\n", "bbox_br = (177.5657, -17.8572)\n", "bbox_tr = (177.5657, -17.6812)" ] }, { "cell_type": "markdown", "id": "36e30686-7555-4866-8157-752418174140", "metadata": {}, "source": [ "Define spatiotemporal query" ] }, { "cell_type": "code", "execution_count": null, "id": "2a1ac2ef", "metadata": {}, "outputs": [], "source": [ "# Define area of interest\n", "area_of_interest = shapely.box(\n", " xmin=bbox_bl[0], ymin=bbox_bl[1], xmax=bbox_tr[0], ymax=bbox_tr[1]\n", ")\n", "\n", "# Define temporal range\n", "daterange: dict = [\"2021-01-01T00:00:00Z\", \"2021-12-31T23:59:59Z\"]" ] }, { "cell_type": "code", "execution_count": null, "id": "ba4e6aeb", "metadata": {}, "outputs": [], "source": [ "catalog_L2A = pystac_client.Client.open(STAC_API_L2A)\n", "\n", "search = catalog_L2A.search(\n", " collections=[COLLECTION_L2A],\n", " datetime=daterange,\n", " intersects=area_of_interest,\n", " max_items=100,\n", " query={\"eo:cloud_cover\": {\"lt\": 80}},\n", ")\n", "\n", "items_L2A = search.get_all_items()\n", "\n", "print(f\"Found {len(items_L2A)} items\")" ] }, { "cell_type": "markdown", "id": "90d5183d", "metadata": {}, "source": [ "## Download the data\n", "Get the data into a numpy array and visualize the imagery. " ] }, { "cell_type": "code", "execution_count": null, "id": "3cccf2b9", "metadata": {}, "outputs": [], "source": [ "# Extract coordinate system from first item\n", "epsg = items_L2A[0].properties[\"proj:epsg\"]\n", "\n", "# Convert point from lon/lat to UTM projection\n", "poidf = gpd.GeoDataFrame(crs=\"OGC:CRS84\", geometry=[area_of_interest.centroid]).to_crs(\n", " epsg\n", ")\n", "geom = poidf.iloc[0].geometry\n", "\n", "# Create bounds of the correct size, the model\n", "# requires 512x512 pixels at 10m resolution.\n", "bounds = (geom.x - 2560, geom.y - 2560, geom.x + 2560, geom.y + 2560)\n", "\n", "# Retrieve the pixel values, for the bounding box in\n", "# the target projection. In this example we use only\n", "# the RGB and SCL band groups.\n", "stack_L2A = stackstac.stack(\n", " items_L2A[0],\n", " bounds=bounds,\n", " snap_bounds=False,\n", " epsg=epsg,\n", " resolution=10,\n", " dtype=\"float32\",\n", " rescale=False,\n", " fill_value=0,\n", " assets=BAND_GROUPS_L2A[\"rgb\"] + BAND_GROUPS_L2A[\"scl\"],\n", " resampling=Resampling.nearest,\n", " xy_coords=\"center\",\n", ")\n", "\n", "stack_L2A = stack_L2A.compute()\n", "print(stack_L2A.shape)\n", "assert stack_L2A.shape == (1, 4, 512, 512)\n", "\n", "stack_L2A.sel(band=[\"red\", \"green\", \"blue\"]).plot.imshow(\n", " row=\"time\", rgb=\"band\", vmin=0, vmax=2000\n", ")" ] }, { "cell_type": "markdown", "id": "a0154b7e-9b75-4bfa-9bd5-c36865d643c1", "metadata": {}, "source": [ "#### Write the stack to file" ] }, { "cell_type": "code", "execution_count": null, "id": "c3973124-c4ea-4e19-984c-19400ae5a2a0", "metadata": {}, "outputs": [], "source": [ "outdir = Path(\"data/minicubes_cloud\")\n", "outdir.mkdir(exist_ok=True, parents=True)\n", "\n", "write = True\n", "if write:\n", " # Write tile to output dir, whilst dropping the SCL band in the process\n", " for tile in stack_L2A.sel(band=[\"red\", \"green\", \"blue\"]):\n", " date = str(tile.time.values)[:10]\n", "\n", " name = \"{dir}/claytile_{date}.tif\".format(\n", " dir=outdir,\n", " date=date.replace(\"-\", \"\"),\n", " )\n", " tile.rio.to_raster(name, compress=\"deflate\")\n", "\n", " with rasterio.open(name, \"r+\") as rst:\n", " rst.update_tags(date=date)" ] }, { "cell_type": "markdown", "id": "1d0f37d4-35bf-4aa8-afb6-20f37d5739c2", "metadata": {}, "source": [ "### Get the geospatial bounds and cloud cover percentages for the 32x32 windows \n", "We will use the geospatial bounds of the 32x32 windowed subsets (\"chunks\") to store the patch level embeddings." ] }, { "cell_type": "code", "execution_count": null, "id": "9e50d97e-e756-4584-acb0-5af01ad0cc75", "metadata": {}, "outputs": [], "source": [ "# Function to count cloud pixels in a subset\n", "\n", "\n", "def count_cloud_pixels(subset_scl, cloud_labels):\n", " cloud_pixels = 0\n", " for label in cloud_labels:\n", " cloud_pixels += numpy.count_nonzero(subset_scl == label)\n", " return cloud_pixels" ] }, { "cell_type": "code", "execution_count": null, "id": "22aa5c54-5336-4b52-a927-6748448a146d", "metadata": {}, "outputs": [], "source": [ "# Define the chunk size for tiling\n", "chunk_size = {\"x\": 32, \"y\": 32} # Adjust the chunk size as needed\n", "\n", "# Tile the data\n", "ds_chunked_L2A = stack_L2A.chunk(chunk_size)\n", "\n", "# Get the dimensions of the data array\n", "dims = ds_chunked_L2A.dims\n", "\n", "# Get the geospatial information from the original dataset\n", "geo_info = ds_chunked_L2A.attrs\n", "\n", "# Iterate over the chunks and compute the geospatial bounds for each chunk\n", "chunk_bounds = {}\n", "\n", "# Iterate over the chunks and compute the cloud count for each chunk\n", "cloud_pcts = {}\n", "\n", "# Get the geospatial transform and CRS\n", "transform = ds_chunked_L2A.attrs[\"transform\"]\n", "crs = ds_chunked_L2A.attrs[\"crs\"]\n", "\n", "for x in range(ds_chunked_L2A.sizes[\"x\"] // chunk_size[\"x\"]): # + 1):\n", " for y in range(ds_chunked_L2A.sizes[\"y\"] // chunk_size[\"y\"]): # + 1):\n", " # Compute chunk coordinates\n", " x_start = x * chunk_size[\"x\"]\n", " y_start = y * chunk_size[\"y\"]\n", " x_end = min(x_start + chunk_size[\"x\"], ds_chunked_L2A.sizes[\"x\"])\n", " y_end = min(y_start + chunk_size[\"y\"], ds_chunked_L2A.sizes[\"y\"])\n", "\n", " # Compute chunk geospatial bounds\n", " lon_start, lat_start = transform * (x_start, y_start)\n", " lon_end, lat_end = transform * (x_end, y_end)\n", " # print(lon_start, lat_start, lon_end, lat_end, x, y)\n", "\n", " # Store chunk bounds\n", " chunk_bounds[(x, y)] = {\n", " \"lon_start\": lon_start,\n", " \"lat_start\": lat_start,\n", " \"lon_end\": lon_end,\n", " \"lat_end\": lat_end,\n", " }\n", "\n", " # Extract the subset of the SCL band\n", " subset_scl = ds_chunked_L2A.sel(band=\"scl\")[:, y_start:y_end, x_start:x_end]\n", "\n", " # Count the cloud pixels in the subset\n", " cloud_pct = count_cloud_pixels(subset_scl, SCL_CLOUD_LABELS)\n", "\n", " # Store the cloud percent for this chunk\n", " cloud_pcts[(x, y)] = int(100 * (cloud_pct / 1024))\n", "\n", "\n", "# Print chunk bounds\n", "# for key, value in chunk_bounds.items():\n", "# print(f\"Chunk {key}: {value}\")\n", "\n", "# Print indices where cloud percentages exceed some interesting threshold\n", "cloud_threshold = 50\n", "for key, value in cloud_pcts.items():\n", " if value > cloud_threshold:\n", " print(f\"Chunk {key}: Cloud percentage = {value}\")" ] }, { "cell_type": "code", "execution_count": null, "id": "d4673aa7-24f5-4d20-b283-5c446020f3d5", "metadata": {}, "outputs": [], "source": [ "DATA_DIR = \"data/minicubes_cloud\"\n", "CKPT_PATH = (\n", " \"https://huggingface.co/made-with-clay/Clay/resolve/main/\"\n", " \"Clay_v0.1_epoch-24_val-loss-0.46.ckpt\"\n", ")\n", "# Load model\n", "multi_model = CLAYModule.load_from_checkpoint(\n", " CKPT_PATH,\n", " mask_ratio=0.0,\n", " band_groups={\"rgb\": (2, 1, 0)},\n", " bands=3,\n", " strict=False, # ignore the extra parameters in the checkpoint\n", " embeddings_level=\"group\",\n", ")\n", "# Set the model to evaluation mode\n", "multi_model.eval()\n", "\n", "\n", "# Load the datamodule, with the reduced set of\n", "class ClayDataModuleMulti(ClayDataModule):\n", " MEAN = [\n", " 1369.03, # red\n", " 1597.68, # green\n", " 1741.10, # blue\n", " ]\n", " STD = [\n", " 2026.96, # red\n", " 2011.88, # green\n", " 2146.35, # blue\n", " ]\n", "\n", "\n", "data_dir = Path(DATA_DIR)\n", "\n", "dm = ClayDataModuleMulti(data_dir=str(data_dir.absolute()), batch_size=1)\n", "dm.setup(stage=\"predict\")\n", "trn_dl = iter(dm.predict_dataloader())" ] }, { "cell_type": "code", "execution_count": null, "id": "eefa3ecc-5aff-43e8-8e8a-e6ea86ed5b07", "metadata": {}, "outputs": [], "source": [ "embeddings = []\n", "for batch in trn_dl:\n", " with torch.no_grad():\n", " # Move data from to the device of model\n", " batch[\"pixels\"] = batch[\"pixels\"].to(multi_model.device)\n", " # Pass just the specific band through the model\n", " batch[\"timestep\"] = batch[\"timestep\"].to(multi_model.device)\n", " batch[\"latlon\"] = batch[\"latlon\"].to(multi_model.device)\n", "\n", " # Pass pixels, latlon, timestep through the encoder to create encoded patches\n", " (\n", " unmasked_patches,\n", " unmasked_indices,\n", " masked_indices,\n", " masked_matrix,\n", " ) = multi_model.model.encoder(batch)\n", " print(unmasked_patches.detach().cpu().numpy())\n", "\n", " embeddings.append(unmasked_patches.detach().cpu().numpy())" ] }, { "cell_type": "code", "execution_count": null, "id": "aae59ecf-2316-4861-b15c-99e90873a04d", "metadata": {}, "outputs": [], "source": [ "print(len(embeddings[0])) # embeddings is a list\n", "print(embeddings[0].shape) # with date and lat/lon\n", "print(embeddings[0][:, :-2, :].shape) # remove date and lat/lon" ] }, { "cell_type": "code", "execution_count": null, "id": "8dbecc66-dce7-4265-b9fa-83f10070bc61", "metadata": {}, "outputs": [], "source": [ "# remove date and lat/lon and reshape to disaggregated patches\n", "embeddings_patch = embeddings[0][:, :-2, :].reshape([1, 16, 16, 768])" ] }, { "cell_type": "code", "execution_count": null, "id": "5ea85d1b-bd4d-45b6-b679-12d10e029707", "metadata": {}, "outputs": [], "source": [ "embeddings_patch.shape" ] }, { "cell_type": "code", "execution_count": null, "id": "e4e621eb-28c6-4aa4-83f2-4ad6c074e468", "metadata": {}, "outputs": [], "source": [ "# average over the band groups\n", "embeddings_patch_avg_group = embeddings_patch.mean(axis=0)" ] }, { "cell_type": "code", "execution_count": null, "id": "4536912c-8919-47a9-9945-f3576ef4f7ab", "metadata": {}, "outputs": [], "source": [ "embeddings_patch_avg_group.shape" ] }, { "cell_type": "markdown", "id": "f7ed0c0e-e404-4c7e-a4b6-65403825b58a", "metadata": {}, "source": [ "### Save the patch level embeddings to independent GeoParquet files\n", "Save the patch level embeddings with the matching geospatial bounds and cloud cover percentages from the chunks we computed earlier. We are correlating patch to chunk bounds based on matching index. This assumes the patches and chunks both define 32x32 subsets with zero overlap." ] }, { "cell_type": "code", "execution_count": null, "id": "c061304a-f5f5-42a4-a62c-a93e17a7c547", "metadata": {}, "outputs": [], "source": [ "outdir_embeddings = Path(\"data/embeddings_cloud\")\n", "outdir_embeddings.mkdir(exist_ok=True, parents=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "5376859d-83d0-4d6b-8308-4ff43699bfdf", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Iterate through each patch\n", "for i in range(embeddings_patch_avg_group.shape[0]):\n", " for j in range(embeddings_patch_avg_group.shape[1]):\n", " embeddings_output_patch = embeddings_patch_avg_group[i, j]\n", "\n", " item_ = [\n", " element for element in list(chunk_bounds.items()) if element[0] == (i, j)\n", " ]\n", " box_ = [\n", " item_[0][1][\"lon_start\"],\n", " item_[0][1][\"lat_start\"],\n", " item_[0][1][\"lon_end\"],\n", " item_[0][1][\"lat_end\"],\n", " ]\n", " cloud_pct_ = [\n", " element for element in list(cloud_pcts.items()) if element[0] == (i, j)\n", " ]\n", " source_url = batch[\"source_url\"]\n", " date = batch[\"date\"]\n", " data = {\n", " \"source_url\": batch[\"source_url\"][0],\n", " \"date\": pd.to_datetime(arg=date, format=\"%Y-%m-%d\").astype(\n", " dtype=\"date32[day][pyarrow]\"\n", " ),\n", " \"embeddings\": [numpy.ascontiguousarray(embeddings_output_patch)],\n", " \"cloud_cover\": cloud_pct_[0][1],\n", " }\n", "\n", " # Define the bounding box as a Polygon (xmin, ymin, xmax, ymax)\n", " # The box_ list is encoded as\n", " # [bottom left x, bottom left y, top right x, top right y]\n", " box_emb = shapely.geometry.box(box_[0], box_[1], box_[2], box_[3])\n", "\n", " # Create the GeoDataFrame\n", " gdf = gpd.GeoDataFrame(data, geometry=[box_emb], crs=f\"EPSG:{epsg}\")\n", "\n", " # Reproject to WGS84 (lon/lat coordinates)\n", " gdf = gdf.to_crs(epsg=4326)\n", "\n", " outpath = (\n", " f\"{outdir_embeddings}/\"\n", " f\"{batch['source_url'][0].split('/')[-1][:-4]}_{i}_{j}.gpq\"\n", " )\n", " gdf.to_parquet(path=outpath, compression=\"ZSTD\", schema_version=\"1.0.0\")\n", " print(\n", " f\"Saved {len(gdf)} rows of embeddings of \"\n", " f\"shape {gdf.embeddings.iloc[0].shape} to {outpath}\"\n", " )" ] }, { "cell_type": "markdown", "id": "5a3d04d2-26d1-4979-b4af-7c72193c4041", "metadata": {}, "source": [ "### Similarity search on the patch embedding level\n", "We will use reference indices based on cloud cover percentage to define a filtered search." ] }, { "cell_type": "code", "execution_count": null, "id": "d0e2c9b5-5bd7-4356-b9e5-ebe6f2b65692", "metadata": {}, "outputs": [], "source": [ "db = lancedb.connect(\"embeddings\")" ] }, { "cell_type": "code", "execution_count": null, "id": "36787000-a1da-4e26-a9d4-120361720b6a", "metadata": {}, "outputs": [], "source": [ "# Data for DB table\n", "data = []\n", "# Dataframe to find overlaps within\n", "gdfs = []\n", "for emb in glob.glob(f\"{outdir_embeddings}/*.gpq\"):\n", " gdf = gpd.read_parquet(emb)\n", " gdf[\"year\"] = gdf.date.dt.year\n", " gdf[\"tile\"] = gdf[\"source_url\"].apply(\n", " lambda x: Path(x).stem.rsplit(\"/\")[-1].rsplit(\"_\")[0]\n", " )\n", " gdf[\"idx\"] = \"_\".join(emb.split(\"/\")[-1].split(\"_\")[2:]).replace(\".gpq\", \"\")\n", " gdf[\"box\"] = [box(*geom.bounds) for geom in gdf.geometry]\n", " gdfs.append(gdf)\n", "\n", " for _, row in gdf.iterrows():\n", " data.append(\n", " {\n", " \"vector\": row[\"embeddings\"],\n", " \"path\": row[\"source_url\"],\n", " \"tile\": row[\"tile\"],\n", " \"date\": row[\"date\"],\n", " \"year\": int(row[\"year\"]),\n", " \"cloud_cover\": row[\"cloud_cover\"],\n", " \"idx\": row[\"idx\"],\n", " \"box\": row[\"box\"].bounds,\n", " }\n", " )" ] }, { "cell_type": "code", "execution_count": null, "id": "4c9ca836-6370-495a-abd3-d5f42a4f5b55", "metadata": {}, "outputs": [], "source": [ "# Combine patch level geodataframes into one\n", "embeddings_gdf = pd.concat(gdfs, ignore_index=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "7e702ea0-fcd2-4c26-9cb1-54b5e86d98c6", "metadata": {}, "outputs": [], "source": [ "embeddings_gdf.columns" ] }, { "cell_type": "markdown", "id": "bf4ad489-32c1-444d-b0e4-d7c63bcd23b8", "metadata": {}, "source": [ "##### (Optional) check on what an embedding's RGB subset looks like" ] }, { "cell_type": "code", "execution_count": null, "id": "f026d88d-09d8-46df-9b91-13a9b9480251", "metadata": {}, "outputs": [], "source": [ "embeddings_gdf_shuffled = embeddings_gdf.sample(frac=1).reset_index(drop=True)\n", "\n", "area_of_interest_embedding = embeddings_gdf_shuffled.box.iloc[0]\n", "\n", "# Extract coordinate system from first item\n", "epsg = items_L2A[0].properties[\"proj:epsg\"]\n", "\n", "# Convert point from lon/lat to UTM projection\n", "box_embedding = gpd.GeoDataFrame(\n", " crs=\"OGC:CRS84\", geometry=[area_of_interest_embedding]\n", ").to_crs(epsg)\n", "geom_embedding = box_embedding.iloc[0].geometry\n", "\n", "# Create bounds of the correct size, the model\n", "# requires 32x32 pixels at 10m resolution.\n", "\n", "# Retrieve the pixel values, for the bounding box in\n", "# the target projection. In this example we use only\n", "# the RGB group.\n", "stack_embedding = stackstac.stack(\n", " items_L2A[0],\n", " bounds=geom_embedding.bounds,\n", " snap_bounds=False,\n", " epsg=epsg,\n", " resolution=10,\n", " dtype=\"float32\",\n", " rescale=False,\n", " fill_value=0,\n", " assets=BAND_GROUPS_L2A[\"rgb\"],\n", " resampling=Resampling.nearest,\n", " xy_coords=\"center\",\n", ")\n", "\n", "stack_embedding = stack_embedding.compute()\n", "assert stack_embedding.shape == (1, 3, 32, 32)\n", "\n", "stack_embedding.sel(band=[\"red\", \"green\", \"blue\"]).plot.imshow(\n", " row=\"time\", rgb=\"band\", vmin=0, vmax=2000\n", ")" ] }, { "cell_type": "markdown", "id": "50d69e03-bf2f-4f3a-86c0-eb466c1e8e68", "metadata": {}, "source": [ "#### Instantiate a dedicated DB table" ] }, { "cell_type": "code", "execution_count": null, "id": "b06f1d53-7edf-423e-9ffa-055c267600db", "metadata": {}, "outputs": [], "source": [ "db.drop_table(\"clay-v001\")\n", "db.table_names()" ] }, { "cell_type": "code", "execution_count": null, "id": "951dd959-844d-43cf-902d-5b0730c314ae", "metadata": {}, "outputs": [], "source": [ "tbl = db.create_table(\"clay-v001\", data=data, mode=\"overwrite\")" ] }, { "cell_type": "markdown", "id": "bd3d1cc1-9d79-4059-a1f6-4ac8cf4d2e51", "metadata": {}, "source": [ "#### Set up filtered searchs" ] }, { "cell_type": "code", "execution_count": null, "id": "48f57d1b-7601-418a-a758-eb56cf69ddbb", "metadata": {}, "outputs": [], "source": [ "# Function to get the average of some list of reference vectors\n", "def get_average_vector(idxs):\n", " reformatted_idxs = [\"_\".join(map(str, idx)) for idx in idxs]\n", " matching_rows = [\n", " tbl.to_pandas().query(f\"idx == '{idx}'\") for idx in reformatted_idxs\n", " ]\n", " matching_vectors = [row.iloc[0][\"vector\"] for row in matching_rows]\n", " vector_mean = numpy.mean(matching_vectors, axis=0)\n", " return vector_mean" ] }, { "cell_type": "markdown", "id": "c094f4c1-22c5-4610-b84b-27153251f723", "metadata": {}, "source": [ "Let's remind ourselves which patches have clouds in them." ] }, { "cell_type": "code", "execution_count": null, "id": "94607dbf-c61f-4ffe-b3d6-d18f26fdacf5", "metadata": {}, "outputs": [], "source": [ "cloud_threshold = 80" ] }, { "cell_type": "code", "execution_count": null, "id": "1191aab5-5348-4650-a09a-55e6b2c3140d", "metadata": {}, "outputs": [], "source": [ "# Get indices for patches where cloud percentages\n", "# exceed some interesting threshold\n", "cloudy_indices = []\n", "for key, value in cloud_pcts.items():\n", " if value > cloud_threshold:\n", " print(f\"Chunk {key}: Cloud percentage = {value}\")\n", " cloudy_indices.append(key)" ] }, { "cell_type": "code", "execution_count": null, "id": "04cc3b3d-9b71-4234-8de8-a118ed9f0f55", "metadata": {}, "outputs": [], "source": [ "v_cloudy = get_average_vector(cloudy_indices)" ] }, { "cell_type": "code", "execution_count": null, "id": "37a66637-c8bf-4ca7-96bd-8189a438f83b", "metadata": {}, "outputs": [], "source": [ "result_cloudy = tbl.search(query=v_cloudy).limit(10).to_pandas()" ] }, { "cell_type": "markdown", "id": "ae669948-b42c-499b-8fdd-c47846c863aa", "metadata": {}, "source": [ "Now let's set up a filtered search for patches that have very little cloud coverage." ] }, { "cell_type": "code", "execution_count": null, "id": "7f000eec-6a74-4f51-8f88-eb5108dcb46a", "metadata": {}, "outputs": [], "source": [ "cloud_threshold = 10" ] }, { "cell_type": "code", "execution_count": null, "id": "b3211745-fbbe-4c97-8ba6-00a54ea3491c", "metadata": { "scrolled": true }, "outputs": [], "source": [ "# Get indices for patches where cloud percentages\n", "# do not exceed some interesting threshold\n", "non_cloudy_indices = []\n", "for key, value in cloud_pcts.items():\n", " if value < cloud_threshold:\n", " # print(f\"Chunk {key}: Cloud percentage = {value}\")\n", " non_cloudy_indices.append(key)" ] }, { "cell_type": "code", "execution_count": null, "id": "3780012f-8cfe-4896-a872-a322fbc38237", "metadata": { "scrolled": true }, "outputs": [], "source": [ "v_non_cloudy = get_average_vector(non_cloudy_indices)" ] }, { "cell_type": "code", "execution_count": null, "id": "a085c025-c9e9-4f26-9333-558bc3891c5e", "metadata": {}, "outputs": [], "source": [ "result_non_cloudy = tbl.search(query=v_non_cloudy).limit(10).to_pandas()" ] }, { "cell_type": "markdown", "id": "613895cc-4d6d-4900-825a-f329361a6505", "metadata": {}, "source": [ "# Plot similar patches" ] }, { "cell_type": "code", "execution_count": null, "id": "23691ff2-3fa8-47e9-9eb9-bfe0731d32b6", "metadata": {}, "outputs": [], "source": [ "def plot(df, cols=10):\n", " fig, axs = plt.subplots(1, cols, figsize=(20, 10))\n", "\n", " row_0 = df.iloc[0]\n", " path = row_0[\"path\"]\n", " chip = rasterio.open(path)\n", " tile = row_0[\"tile\"]\n", " width = chip.width\n", " height = chip.height\n", " # Define the window size\n", " window_size = (32, 32)\n", "\n", " idxs_windows = {\"idx\": [], \"window\": []}\n", "\n", " # Iterate over the image in 32x32 windows\n", " for col in range(0, width, window_size[0]):\n", " for row in range(0, height, window_size[1]):\n", " # Define the window\n", " window = ((row, row + window_size[1]), (col, col + window_size[0]))\n", "\n", " # Read the data within the window\n", " data = chip.read(window=window)\n", "\n", " # Get the index of the window\n", " index = (col // window_size[0], row // window_size[1])\n", "\n", " # Process the window data here\n", " # For example, print the index and the shape of the window data\n", " # print(\"Index:\", index)\n", " # print(\"Window Shape:\", data.shape)\n", "\n", " idxs_windows[\"idx\"].append(\"_\".join(map(str, index)))\n", " idxs_windows[\"window\"].append(data)\n", "\n", " # print(idxs_windows)\n", "\n", " for ax, (_, row) in zip(axs.flatten(), df.iterrows()):\n", " idx = row[\"idx\"]\n", " # Find the corresponding window based on the idx\n", " window_index = idxs_windows[\"idx\"].index(idx)\n", " window_data = idxs_windows[\"window\"][window_index]\n", " # print(window_data.shape)\n", " subset_img = numpy.clip(\n", " (window_data.transpose(1, 2, 0)[:, :, :3] / 10_000) * 3, 0, 1\n", " )\n", " ax.imshow(subset_img)\n", " ax.set_title(f\"{tile}/{idx}/{row.cloud_cover}\")\n", " ax.set_axis_off()\n", " plt.tight_layout()\n", " fig.savefig(\"similar.png\")" ] }, { "cell_type": "markdown", "id": "e7eb6b08-65cd-4479-8a7f-c2a54b2ebfcb", "metadata": {}, "source": [ "#### Result from searching for cloudy samples" ] }, { "cell_type": "code", "execution_count": null, "id": "3ce75d35-749c-41ad-9dcb-57ee030082b3", "metadata": {}, "outputs": [], "source": [ "plot(result_cloudy)" ] }, { "cell_type": "markdown", "id": "4d506e8e-0519-4850-b9a0-1d7f095ece94", "metadata": {}, "source": [ "#### Result from searching for non-cloudy samples" ] }, { "cell_type": "code", "execution_count": null, "id": "020660f9-492e-4142-91ac-9c78215b5662", "metadata": {}, "outputs": [], "source": [ "plot(result_non_cloudy)" ] }, { "cell_type": "markdown", "id": "32f5c04c-10a6-4882-9123-61da69422cd0", "metadata": {}, "source": [ "#### Visualize the area of interest with the cloudy and non-cloudy patch results" ] }, { "cell_type": "code", "execution_count": null, "id": "a8db3487-21e2-4865-84e7-90316cd501f3", "metadata": {}, "outputs": [], "source": [ "# Make geodataframe of the search results\n", "# cloudy\n", "result_cloudy_boxes = [\n", " Polygon(\n", " [(bbox[0], bbox[1]), (bbox[2], bbox[1]), (bbox[2], bbox[3]), (bbox[0], bbox[3])]\n", " )\n", " for bbox in result_cloudy[\"box\"]\n", "]\n", "result_cloudy_gdf = gpd.GeoDataFrame(result_cloudy, geometry=result_cloudy_boxes)\n", "result_cloudy_gdf.crs = \"EPSG:4326\"\n", "# non-cloudy\n", "result_non_cloudy_boxes = [\n", " Polygon(\n", " [(bbox[0], bbox[1]), (bbox[2], bbox[1]), (bbox[2], bbox[3]), (bbox[0], bbox[3])]\n", " )\n", " for bbox in result_non_cloudy[\"box\"]\n", "]\n", "result_non_cloudy_gdf = gpd.GeoDataFrame(\n", " result_non_cloudy, geometry=result_non_cloudy_boxes\n", ")\n", "result_non_cloudy_gdf.crs = \"EPSG:4326\"\n", "\n", "# Plot the AOI in RGB\n", "plot = stack_L2A.sel(band=[\"B04\", \"B03\", \"B02\"]).plot\n", "plot.imshow(row=\"time\", rgb=\"band\", vmin=0, vmax=2000)\n", "\n", "# Overlay the bounding boxes of the patches identified from the similarity search\n", "result_cloudy_gdf.to_crs(epsg).plot(ax=plt.gca(), color=\"red\", alpha=0.5)\n", "result_non_cloudy_gdf.to_crs(epsg).plot(ax=plt.gca(), color=\"blue\", alpha=0.5)\n", "\n", "\n", "# Set plot title and labels\n", "plt.title(\"Sentinel-2 with cloudy and non-cloudy embeddings\")\n", "plt.xlabel(\"Longitude\")\n", "plt.ylabel(\"Latitude\")\n", "\n", "# Show the plot\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "5d439d49-0806-4c4f-ae22-1f202d5e8ede", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "jupytext": { "formats": "ipynb,py:percent" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.13" } }, "nbformat": 4, "nbformat_minor": 5 }