{ "cells": [ { "cell_type": "markdown", "id": "0cc5e729-9116-4ec9-bf1e-8346cbccdf7b", "metadata": {}, "source": [ "## Run Clay v1\n", "\n", "This notebook shows how to run Clay v1 wall-to-wall, from downloading imagery\n", "to training a tiny fine-tuning head. It covers the following steps:\n", "\n", "1. Set a location and date range of interest\n", "2. Download Sentinel-2 imagery for this specification\n", "3. Load the model checkpoint\n", "4. Prepare the data in the format the model expects\n", "5. Run the model on the imagery\n", "6. Analyse the output embeddings using PCA\n", "7. Train a Support Vector Machine (SVM) fine-tuning head" ] }, { "cell_type": "code", "execution_count": 1, "id": "add63cd9", "metadata": {}, "outputs": [], "source": [ "# Add the repo root to sys.path so that src.model can be imported below\n", "import sys\n", "\n", "sys.path.append(\"..\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "6a17b8a8-a9c6-4053-833e-de97287fae49", "metadata": {}, "outputs": [], "source": [ "import math\n", "\n", "import geopandas as gpd\n", "import numpy as np\n", "import pandas as pd\n", "import pystac_client\n", "import stackstac\n", "import torch\n", "import yaml\n", "from box import Box\n", "from matplotlib import pyplot as plt\n", "from rasterio.enums import Resampling\n", "from shapely import Point\n", "from sklearn import decomposition, svm\n", "from torchvision.transforms import v2\n", "\n", "from src.model import ClayMAEModule" ] }, { "cell_type": "markdown", "id": "beac6394-9762-422b-9f5d-82d226018c0c", "metadata": {}, "source": [ "### Specify location and date of interest\n", "In this example we use a location in Portugal where a large forest fire occurred. We run the model over the time period of the fire and analyse the resulting embeddings." 
] }, { "cell_type": "code", "execution_count": 3, "id": "08d7787d-1506-4de7-89dc-c1054910acf7", "metadata": {}, "outputs": [], "source": [ "# Point over Monchique, Portugal\n", "lat, lon = 37.30939, -8.57207\n", "\n", "# Dates of a large forest fire\n", "start = \"2018-07-01\"\n", "end = \"2018-09-01\"" ] }, { "cell_type": "markdown", "id": "2bd226c9-003b-4867-a64a-8ae887e7e20a", "metadata": {}, "source": [ "### Get data from a STAC catalog\n", "\n", "Based on the location and date, we can obtain a stack of imagery using stackstac. Let's start by finding the STAC items we want to analyse." ] }, { "cell_type": "code", "execution_count": 4, "id": "2e80743c-7c77-459b-9984-f6c26cdff549", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Found 12 items\n" ] } ], "source": [ "STAC_API = \"https://earth-search.aws.element84.com/v1\"\n", "COLLECTION = \"sentinel-2-l2a\"\n", "\n", "# Search the catalogue\n", "catalog = pystac_client.Client.open(STAC_API)\n", "search = catalog.search(\n", "    collections=[COLLECTION],\n", "    datetime=f\"{start}/{end}\",\n", "    bbox=(lon - 1e-5, lat - 1e-5, lon + 1e-5, lat + 1e-5),\n", "    max_items=100,\n", "    query={\"eo:cloud_cover\": {\"lt\": 80}},\n", ")\n", "\n", "# item_collection() replaces the deprecated get_all_items()\n", "all_items = search.item_collection()\n", "\n", "# Keep one item per date (the point can fall in overlapping\n", "# tiles, which produces duplicate acquisitions)\n", "items = []\n", "dates = []\n", "for item in all_items:\n", "    if item.datetime.date() not in dates:\n", "        items.append(item)\n", "        dates.append(item.datetime.date())\n", "\n", "print(f\"Found {len(items)} items\")" ] }, { "cell_type": "markdown", "id": "5b7c68ae-7c8a-446a-8bc7-5afba70183c2", "metadata": {}, "source": [ "### Create a bounding box around the 
point of interest\n", "\n", "The box has to be defined in the projection of the imagery so that we can cut image chips of the right size (here 256 x 256 pixels at a 10 m ground sample distance)." ] }, { "cell_type": "code", "execution_count": 5, "id": "0f3573b5-5a00-47d9-a648-5c4d7cd2c996", "metadata": {}, "outputs": [], "source": [ "# Extract coordinate system from first item\n", "epsg = items[0].properties[\"proj:epsg\"]\n", "\n", "# Convert point of interest into the image projection\n", "# (assumes all images are in the same projection)\n", "poidf = gpd.GeoDataFrame(\n", "    pd.DataFrame(),\n", "    crs=\"EPSG:4326\",\n", "    geometry=[Point(lon, lat)],\n", ").to_crs(epsg)\n", "\n", "coords = poidf.iloc[0].geometry.coords[0]\n", "\n", "# Create square bounds (size * gsd metres wide) centred on the point\n", "size = 256  # chip size in pixels\n", "gsd = 10  # ground sample distance in metres\n", "bounds = (\n", "    coords[0] - (size * gsd) // 2,\n", "    coords[1] - (size * gsd) // 2,\n", "    coords[0] + (size * gsd) // 2,\n", "    coords[1] + (size * gsd) // 2,\n", ")" ] }, { "cell_type": "markdown", "id": "bbbd3f67-5f2c-46dc-9ee1-2ef1f50fa032", "metadata": {}, "source": [ "### Retrieve the imagery data" ] }, { "cell_type": "code", "execution_count": 6, "id": "8b8d3824-e48c-4f9d-9c7b-181c0800f96f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ " Size: 13MB\n", "dask.array\n", "Coordinates: (12/53)\n", " * time (time) datetime64[ns] 96B 2018-0...\n", " id (time) " ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" }, { "data": { "image/png": "", "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Reduce the embeddings to a single principal component\n", "pca = decomposition.PCA(n_components=1)\n", "pca_result = pca.fit_transform(embeddings)\n", "\n", "plt.xticks(rotation=-45)\n", "\n", "# Plot all points in blue first\n", "plt.scatter(stack.time, pca_result, color=\"blue\")\n", "\n", "# Re-plot the two cloudy images (indices 0 and 2) in green\n", "plt.scatter(stack.time[0], pca_result[0], color=\"green\")\n", "plt.scatter(stack.time[2], pca_result[2], color=\"green\")\n", "\n", "# Re-plot the last five images, taken after the fire, in red\n", "plt.scatter(stack.time[-5:], pca_result[-5:], color=\"red\")" ] }, { "cell_type": "markdown", "id": "b38b70a6-2156-41f8-967e-a490cc8e2778", "metadata": {}, "source": [ "### And finally, some fine-tuning\n", "\n", "We train a small classifier head on the embeddings and use it to detect fires." ] }, { "cell_type": "code", "execution_count": 14, "id": "1da07de0-b8f2-46c9-bd2a-58b15ca2224f", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Matched 5 out of 5 correctly\n" ] } ], "source": [ "# Label the images we downloaded\n", "# 0 = Cloud\n", "# 1 = Forest\n", "# 2 = Fire\n", "labels = np.array([0, 1, 0, 1, 1, 1, 1, 2, 2, 2, 2, 2])\n", "\n", "# Split into fit and test sets manually, ensuring all 3 classes appear in both\n", "fit = [0, 1, 3, 4, 7, 8, 9]\n", "test = [2, 5, 6, 10, 11]\n", "\n", "# Train a support vector machine model\n", "clf = svm.SVC()\n", "clf.fit(embeddings[fit] + 100, labels[fit])\n", "\n", "# Predict classes on test set\n", "prediction = clf.predict(embeddings[test] + 100)\n", "\n", "# Count how many test predictions match the labels\n", "match = np.sum(labels[test] == prediction)\n", "print(f\"Matched {match} out of {len(test)} correctly\")" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", 
"nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.8" } }, "nbformat": 4, "nbformat_minor": 5 }