{ "cells": [ { "cell_type": "markdown", "id": "e66cec10-75b7-4a14-b9b4-91f4d707bb6d", "metadata": {}, "source": [ "# CLAY v0 - Location Embeddings" ] }, { "cell_type": "markdown", "id": "65130e67-0868-4e6e-b181-4c456223f998", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "id": "ea0176a6-97a1-4af6-af75-b9e52e52fbaf", "metadata": {}, "outputs": [], "source": [ "import sys\n", "\n", "sys.path.append(\"../\")" ] }, { "cell_type": "code", "execution_count": null, "id": "5ea314d0-176a-4ee3-b738-6152d27275d9", "metadata": {}, "outputs": [], "source": [ "import warnings\n", "from pathlib import Path\n", "\n", "import lightning as L\n", "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import rasterio as rio\n", "import torch\n", "from sklearn.cluster import KMeans\n", "from sklearn.decomposition import PCA\n", "\n", "from src.datamodule import ClayDataModule, ClayDataset\n", "from src.model_clay import CLAYModule\n", "\n", "warnings.filterwarnings(\"ignore\")" ] }, { "cell_type": "code", "execution_count": null, "id": "65b3babb-ab89-40f1-920c-ac9b88dc9738", "metadata": {}, "outputs": [], "source": [ "L.seed_everything(42)" ] }, { "cell_type": "code", "execution_count": null, "id": "b12b28e1-c8c6-470d-8b38-5600e4897074", "metadata": {}, "outputs": [], "source": [ "# data directory for all chips\n", "DATA_DIR = \"../data/02\"\n", "# path of best model checkpoint for Clay v0\n", "CKPT_PATH = \"../checkpoints/v0/mae_epoch-24_val-loss-0.46.ckpt\"" ] }, { "cell_type": "markdown", "id": "7bade738-5ed9-4530-a9a6-8ade3cb4d6d8", "metadata": {}, "source": [ "## Load Model & DataModule" ] }, { "cell_type": "code", "execution_count": null, "id": "4c5f2abf-5e9c-4def-88d9-38136307b420", "metadata": {}, "outputs": [], "source": [ "# Load the model & set in eval mode\n", "model = CLAYModule.load_from_checkpoint(CKPT_PATH, mask_ratio=0.7)\n", "model.eval();" ] }, { "cell_type": "code", "execution_count": null, "id": 
"348c0573-7670-47a6-9e13-c6de36493b58", "metadata": {}, "outputs": [], "source": [ "data_dir = Path(DATA_DIR)\n", "\n", "# Load the Clay DataModule\n", "ds = ClayDataset(chips_path=list(data_dir.glob(\"**/*.tif\")))\n", "dm = ClayDataModule(data_dir=str(data_dir), batch_size=100)\n", "dm.setup(stage=\"fit\")\n", "\n", "# Load the train DataLoader\n", "trn_dl = iter(dm.train_dataloader())" ] }, { "cell_type": "code", "execution_count": null, "id": "0fd6d1ca-4cde-48c9-842a-8eb16279a534", "metadata": {}, "outputs": [], "source": [ "# Load the first batch of chips\n", "batch = next(trn_dl)\n", "batch.keys()" ] }, { "cell_type": "code", "execution_count": null, "id": "f5e6dfc4-57fe-4296-9a57-9bd384ec61af", "metadata": {}, "outputs": [], "source": [ "# Save a copy of batch to visualize later\n", "_batch = batch[\"pixels\"].detach().clone().cpu().numpy()" ] }, { "cell_type": "markdown", "id": "db087380-c94c-40a2-8746-2e33b7903b3d", "metadata": {}, "source": [ "## Pass model through the CLAY model" ] }, { "cell_type": "code", "execution_count": null, "id": "65424c5e-f68d-47b0-bf87-030d3cf9b7e8", "metadata": {}, "outputs": [], "source": [ "# Pass the pixels through the encoder & decoder of CLAY\n", "with torch.no_grad():\n", " # Move data from to the device of model\n", " batch[\"pixels\"] = batch[\"pixels\"].to(model.device)\n", " batch[\"timestep\"] = batch[\"timestep\"].to(model.device)\n", " batch[\"latlon\"] = batch[\"latlon\"].to(model.device)\n", "\n", " # Pass pixels, latlon, timestep through the encoder to create encoded patches\n", " (\n", " unmasked_patches,\n", " unmasked_indices,\n", " masked_indices,\n", " masked_matrix,\n", " ) = model.model.encoder(batch)\n", "\n", " # Pass the unmasked_patches through the decoder to reconstruct the pixel space\n", " pixels = model.model.decoder(unmasked_patches, unmasked_indices, masked_indices)" ] }, { "cell_type": "markdown", "id": "c601956c-d18c-4cdf-a4fe-0a30e4b06786", "metadata": {}, "source": [ "## Extract Location & 
Timestep Embeddings" ] }, { "cell_type": "markdown", "id": "a8160012-d0d8-4eb8-b87a-f5c02cd980d6", "metadata": {}, "source": [ "In CLAY, the encoder receives the unmasked patches together with latitude-longitude and timestep information. The last two embeddings in the encoder output are the latitude-longitude and timestep embeddings, in that order." ] }, { "cell_type": "code", "execution_count": null, "id": "baa7adfb-3503-415b-a892-00edcfe1f127", "metadata": {}, "outputs": [], "source": [ "latlon_embeddings = unmasked_patches[:, -2, :].detach().cpu().numpy()\n", "time_embeddings = unmasked_patches[:, -1, :].detach().cpu().numpy()\n", "\n", "# Get the normalized latlon values that were input to the model\n", "latlon = batch[\"latlon\"].detach().cpu().numpy()" ] }, { "cell_type": "markdown", "id": "da8ce686-9982-4845-a171-2729365ead8e", "metadata": {}, "source": [ "We will focus only on the location embeddings in this notebook." ] }, { "cell_type": "code", "execution_count": null, "id": "591d1d55-a335-4925-8f45-48c9584106d7", "metadata": {}, "outputs": [], "source": [ "latlon.shape, latlon_embeddings.shape" ] }, { "cell_type": "markdown", "id": "3384f479-ef84-420d-a4e9-e3b038f05497", "metadata": {}, "source": [ "> Latitude & longitude map to a 768-dimensional vector" ] }, { "cell_type": "markdown", "id": "9e419fc9-e7d3-49de-a8ea-72912c365510", "metadata": {}, "source": [ "## Perform PCA over the location embeddings to visualize them in 2 dimensions" ] }, { "cell_type": "code", "execution_count": null, "id": "7b7fc190-25c2-47b6-9e0a-cf7bc9d558d7", "metadata": {}, "outputs": [], "source": [ "pca = PCA(n_components=2)\n", "latlon_embeddings = pca.fit_transform(latlon_embeddings)\n", "latlon_embeddings.shape" ] }, { "cell_type": "markdown", "id": "92c92119-282a-43ca-aad4-a01f2d7c278b", "metadata": {}, "source": [ "## Create clusters of normalized latlon & latlon embeddings to check if there are any learned patterns in them after training" ] }, { "cell_type": "markdown", "id": 
"6d5127bd-e8c1-4ddc-a135-4ddcafa974cf", "metadata": {}, "source": [ "Latlon Cluster" ] }, { "cell_type": "code", "execution_count": null, "id": "330ae1a0-f94f-4d97-8a3f-eb507773585b", "metadata": {}, "outputs": [], "source": [ "kmeans = KMeans(n_clusters=5)\n", "kmeans.fit_transform(latlon)\n", "latlon = np.column_stack((latlon, kmeans.labels_))" ] }, { "cell_type": "markdown", "id": "d753e7e2-3550-431d-bc65-f8e2da857be1", "metadata": {}, "source": [ "Latlon Embeddings Cluster" ] }, { "cell_type": "code", "execution_count": null, "id": "f95256c1-65f3-4653-80e9-9c7ad31eb475", "metadata": {}, "outputs": [], "source": [ "kmeans = KMeans(n_clusters=5)\n", "kmeans.fit_transform(latlon_embeddings)\n", "latlon_embeddings = np.column_stack((latlon_embeddings, kmeans.labels_))" ] }, { "cell_type": "code", "execution_count": null, "id": "670c84f6-8041-4643-8c80-e039f2aa4400", "metadata": {}, "outputs": [], "source": [ "latlon.shape, latlon_embeddings.shape" ] }, { "cell_type": "markdown", "id": "94e3739f-38c6-40cf-9457-904cd6c56324", "metadata": {}, "source": [ "> We are a third dimension to latlon & latlon embeddings with cluster labels" ] }, { "cell_type": "markdown", "id": "a259f9d0-05dc-4dab-bf49-c18a510d3d3a", "metadata": {}, "source": [ "## Plot latlon clusters" ] }, { "cell_type": "code", "execution_count": null, "id": "142a72ee-f88a-4cf6-bdd5-3e24fa7bd3aa", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(15, 15), dpi=80)\n", "plt.scatter(latlon[:, 0], latlon[:, 1], c=latlon[:, 2], label=\"Actual\", alpha=0.3)\n", "\n", "for i in range(100):\n", " txt = f\"{latlon[:,0][i]:.2f},{latlon[:, 1][i]:.2f}\"\n", " plt.annotate(txt, (latlon[:, 0][i] + 1e-5, latlon[:, 1][i] + 1e-5))" ] }, { "cell_type": "markdown", "id": "91e1a739-9536-4401-98ed-6285ff51b09a", "metadata": {}, "source": [ "> As we see in the scatter plot above, there is nothing unique about latlon that go into the model, they are cluster based on their change in longitude values above" ] }, { 
"cell_type": "markdown", "id": "df76148a-b05f-4f52-8b55-1ea5ecac9f84", "metadata": {}, "source": [ "## Plot latlon embeddings cluster" ] }, { "cell_type": "code", "execution_count": null, "id": "c5d51daa-1541-49a4-9b1d-285189136d34", "metadata": {}, "outputs": [], "source": [ "plt.figure(figsize=(15, 15), dpi=80)\n", "plt.scatter(\n", " latlon_embeddings[:, 0],\n", " latlon_embeddings[:, 1],\n", " c=latlon_embeddings[:, 2],\n", " label=\"Predicted\",\n", " alpha=0.3,\n", ")\n", "for i in range(100):\n", " txt = i\n", " plt.annotate(txt, (latlon_embeddings[:, 0][i], latlon_embeddings[:, 1][i]))" ] }, { "cell_type": "code", "execution_count": null, "id": "0ebbad29-4abf-4a4b-9c94-455edc8dac40", "metadata": {}, "outputs": [], "source": [ "def show_cluster(ids):\n", " fig, axes = plt.subplots(1, len(ids), figsize=(10, 5))\n", " for i, ax in zip(ids, axes.flatten()):\n", " img_path = batch[\"source_url\"][i]\n", " img = rio.open(img_path).read([3, 2, 1]).transpose(1, 2, 0)\n", " img = (img - img.min()) / (img.max() - img.min())\n", " ax.imshow(img)\n", " ax.set_axis_off()" ] }, { "cell_type": "code", "execution_count": null, "id": "797f9b1c-17db-45c3-abcd-e7b16886f591", "metadata": {}, "outputs": [], "source": [ "show_cluster((87, 37, 40))" ] }, { "cell_type": "code", "execution_count": null, "id": "5ae6b021-24e9-408b-a558-f322f4e7bc2d", "metadata": {}, "outputs": [], "source": [ "show_cluster((23, 11, 41))" ] }, { "cell_type": "code", "execution_count": null, "id": "224d210b-642c-4ad4-acc0-f1b009a7f0a5", "metadata": {}, "outputs": [], "source": [ "show_cluster((68, 71, 7))" ] }, { "cell_type": "markdown", "id": "7c8d1fde-fefc-4e4c-8518-f803411baa01", "metadata": {}, "source": [ "> We can see location embedding capturing semantic information as well" ] }, { "cell_type": "code", "execution_count": null, "id": "4910938d-611d-47a8-abc3-e94e3e152dd1", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", 
"language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.7" } }, "nbformat": 4, "nbformat_minor": 5 }