{ "cells": [ { "cell_type": "markdown", "id": "992c4084-df6f-4ba6-af58-83f1b7fcb7eb", "metadata": {}, "source": [ "# Let us visualize how Clay reconstructs each sensor from 70% masked inputs" ] }, { "cell_type": "code", "execution_count": 1, "id": "e0c141b9-4038-4542-832c-f71e04bd93c1", "metadata": {}, "outputs": [], "source": [ "import sys\n", "import warnings\n", "\n", "sys.path.append(\"..\")\n", "warnings.filterwarnings(action=\"ignore\")" ] }, { "cell_type": "code", "execution_count": 2, "id": "caedb33f-103b-4114-8029-6e0953dbf7aa", "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "import torch\n", "from einops import rearrange\n", "\n", "from src.datamodule import ClayDataModule\n", "from src.model import ClayMAEModule" ] }, { "cell_type": "code", "execution_count": 3, "id": "08f83855-0cac-4a98-979b-177b5e96c325", "metadata": {}, "outputs": [], "source": [ "DATA_DIR = \"/home/ubuntu/data\"\n", "CHECKPOINT_PATH = \"../checkpoints/v0.5.7/mae_v0.5.7_epoch-13_val-loss-0.3098.ckpt\"\n", "METADATA_PATH = \"../configs/metadata.yaml\"\n", "CHIP_SIZE = 224\n", "MASK_RATIO = 0.7 # 70% masking of inputs\n", "DEVICE = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")" ] }, { "cell_type": "markdown", "id": "49ebc53b-1a87-407c-9378-aa2cb8cd274d", "metadata": {}, "source": [ "### MODEL\n", "\n", "Load the model with best checkpoint path and set it in eval mode. Set mask_ratio to `70%` & shuffle to `True`." ] }, { "cell_type": "code", "execution_count": 4, "id": "4016a37e-d80f-4a62-afac-bc1ac2d67ce5", "metadata": {}, "outputs": [], "source": [ "module = ClayMAEModule.load_from_checkpoint(\n", " checkpoint_path=CHECKPOINT_PATH,\n", " metadata_path=METADATA_PATH,\n", " mask_ratio=0.7,\n", " shuffle=True,\n", ")\n", "\n", "module.eval();" ] }, { "cell_type": "markdown", "id": "f8c57245-5d19-483c-8051-da413bb425df", "metadata": {}, "source": [ "## DATAMODULE\n", "\n", "Load the ClayDataModule" ] }, { "cell_type": "code", "execution_count": 5, "id": "3e594d84-86bd-4f1d-83a2-8174f475068c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Total number of chips: 193\n" ] } ], "source": [ "# For model training, we stack chips from one sensor into batches of size 128.\n", "# This reduces the num_workers we need to load the batches and speeds up the\n", "# training process. Here, although the batch size is 1, the data module reads\n", "# batch of size 128.\n", "dm = ClayDataModule(\n", " data_dir=DATA_DIR,\n", " metadata_path=METADATA_PATH,\n", " size=CHIP_SIZE,\n", " batch_size=1,\n", " num_workers=1,\n", ")\n", "dm.setup(stage=\"fit\")" ] }, { "cell_type": "markdown", "id": "b47b4eab-41e0-4acf-b8a9-96cd6fca7563", "metadata": {}, "source": [ "Let us look at the data directory.\n", "\n", "We have a folder for each sensor, i.e:\n", "- Landsat l1\n", "- Landsat l2\n", "- Sentinel 1 rtc\n", "- Sentinel 2 l2a\n", "- Naip\n", "- Linz\n", "\n", "And, under each folder, we have stacks of chips as `.npz` files." 
] }, { "cell_type": "code", "execution_count": 6, "id": "4972a093-1360-46cd-8a6e-1956a8818ac7", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[01;34m/home/ubuntu/data\u001b[00m\n", "├── \u001b[01;34mlandsat-c2l1\u001b[00m\n", "├── \u001b[01;34mlandsat-c2l2-sr\u001b[00m\n", "├── \u001b[01;34mlinz\u001b[00m\n", "├── \u001b[01;34mnaip\u001b[00m\n", "├── \u001b[01;34msentinel-1-rtc\u001b[00m\n", "└── \u001b[01;34msentinel-2-l2a\u001b[00m\n", "\n", "6 directories, 0 files\n" ] } ], "source": [ "!tree -L 1 {DATA_DIR}" ] }, { "cell_type": "code", "execution_count": 7, "id": "1164f64d-400c-4505-95fa-77f820e29e15", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\u001b[01;34m/home/ubuntu/data/naip\u001b[00m\n", "├── cube_10.npz\n", "├── cube_100045.npz\n", "├── cube_100046.npz\n", "├── cube_100072.npz\n" ] } ], "source": [ "!tree -L 2 {DATA_DIR}/naip | head -5" ] }, { "cell_type": "markdown", "id": "ada5b6b6-f04a-4293-b6a2-46b1407c27bd", "metadata": {}, "source": [ "Now, lets look at what we have in each of the `.npz` files." ] }, { "cell_type": "code", "execution_count": 8, "id": "18546298-c291-46ed-b95d-48aae22353e6", "metadata": {}, "outputs": [], "source": [ "sample = np.load(\"/home/ubuntu/data/naip/cube_10.npz\")" ] }, { "cell_type": "code", "execution_count": 9, "id": "45235d8e-1048-4f60-af84-24e2af505daf", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "KeysView(NpzFile '/home/ubuntu/data/naip/cube_10.npz' with keys: pixels, lon_norm, lat_norm, week_norm, hour_norm)" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sample.keys()" ] }, { "cell_type": "code", "execution_count": 10, "id": "aa31731c-628d-44ad-ad6f-d8bc4467978b", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "(128, 4, 256, 256)" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sample[\"pixels\"].shape" ] }, { "cell_type": "code", "execution_count": 11, "id": "00c4f270-4b43-47eb-9609-c3609cf8caca", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((128, 2), (128, 2))" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sample[\"lat_norm\"].shape, sample[\"lon_norm\"].shape" ] }, { "cell_type": "code", "execution_count": 12, "id": "7926d079-e49d-44a7-9fe6-674d5a4abf7e", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "((128, 2), (128, 2))" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sample[\"week_norm\"].shape, sample[\"hour_norm\"].shape" ] }, { "cell_type": "markdown", "id": "bcd51378-36c5-4a69-9bbc-944ba31bd532", "metadata": {}, "source": [ "As we see above, chips are stacked in batches of size `128`. \n", "The sample we are looking at is from `NAIP` so it has 4 bands & of size `256 x 256`. \n", "We also get normalized lat/lon & timestep (hour/week) information that is *(optionally required) by the model. If you don't have this handy, feel free to pass zero tensors in their place." ] }, { "cell_type": "markdown", "id": "7d13d854-ed5e-4425-b41f-2e9f3372a44c", "metadata": {}, "source": [ "Load a batch of data from ClayDataModule\n", "\n", "ClayDataModule is designed to fetch random batches of data from different sensors sequentially, i.e batches are in ascending order of their directory - Landsat 1, Landsat 2, LINZ, NAIP, Sentinel 1 rtc, Sentinel 2 L2A and it repeats after that. 
" ] }, { "cell_type": "code", "execution_count": 13, "id": "c299e445-9b45-4483-b254-fc67b5a85e1e", "metadata": {}, "outputs": [], "source": [ "# We have a random sample subset of the data, so it's\n", "# okay to use either the train or val dataloader.\n", "dl = iter(dm.train_dataloader())" ] }, { "cell_type": "code", "execution_count": 14, "id": "7fe2ace3-06df-4768-8fa2-d51e2e9a02b8", "metadata": {}, "outputs": [], "source": [ "l1 = next(dl)\n", "l2 = next(dl)\n", "linz = next(dl)\n", "naip = next(dl)\n", "s1 = next(dl)\n", "s2 = next(dl)" ] }, { "cell_type": "code", "execution_count": 15, "id": "25e895ec-c22b-4589-a0b1-2394c1aaeeac", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "landsat-c2l1 torch.Size([128, 6, 224, 224]) torch.Size([128, 4]) torch.Size([128, 4])\n", "landsat-c2l2-sr torch.Size([128, 6, 224, 224]) torch.Size([128, 4]) torch.Size([128, 4])\n", "linz torch.Size([128, 3, 224, 224]) torch.Size([128, 4]) torch.Size([128, 4])\n", "naip torch.Size([128, 4, 224, 224]) torch.Size([128, 4]) torch.Size([128, 4])\n", "sentinel-1-rtc torch.Size([128, 2, 224, 224]) torch.Size([128, 4]) torch.Size([128, 4])\n", "sentinel-2-l2a torch.Size([128, 10, 224, 224]) torch.Size([128, 4]) torch.Size([128, 4])\n" ] } ], "source": [ "for sensor, chips in zip(\n", " (\"l1\", \"l2\", \"linz\", \"naip\", \"s1\", \"s2\"), (l1, l2, linz, naip, s1, s2)\n", "):\n", " print(\n", " f\"{chips['platform'][0]:<15}\",\n", " chips[\"pixels\"].shape,\n", " chips[\"time\"].shape,\n", " chips[\"latlon\"].shape,\n", " )" ] }, { "cell_type": "markdown", "id": "45b338dc-00aa-4bdd-bb50-8d067b05e34b", "metadata": {}, "source": [ "## INPUT\n", "\n", "Model expects a dictionary with keys:\n", "- pixels: `batch x band x height x width` - normalized chips of a sensor\n", "- time: `batch x 4` - horizontally stacked `week_norm` & `hour_norm`\n", "- latlon: `batch x 4` - horizontally stacked `lat_norm` & `lon_norm`\n", "- waves: `list[:band]` - wavelengths of each band of the sensor from the `metadata.yaml` file\n", "- gsd: `scalar` - gsd of the sensor from `metadata.yaml` file\n", "\n", "Normalization & stacking is taken care of by the ClayDataModule: https://github.com/Clay-foundation/model/blob/f872f098224d64677ed96b6a49974bb7ddef10dc/src/datamodule.py#L55-L72\n", "\n", "When not using the ClayDataModule, make sure you normalize the chips & pass all items for the model. " ] }, { "cell_type": "code", "execution_count": 16, "id": "e88c21ef-e104-43f0-a343-e28eea79fc69", "metadata": {}, "outputs": [], "source": [ "def create_batch(chips, wavelengths, gsd, device):\n", " batch = {}\n", "\n", " batch[\"pixels\"] = chips[\"pixels\"].to(device)\n", " batch[\"time\"] = chips[\"time\"].to(device)\n", " batch[\"latlon\"] = chips[\"latlon\"].to(device)\n", "\n", " batch[\"waves\"] = torch.tensor(wavelengths)\n", " batch[\"gsd\"] = torch.tensor(gsd)\n", "\n", " return batch" ] }, { "cell_type": "markdown", "id": "f2e16900-4ebf-4ecf-a7f5-0e4420617a4d", "metadata": {}, "source": [ "### FORWARD PASS\n", "\n", "Pass a batch from single sensor through the Encoder & Decoder of Clay." 
] }, { "cell_type": "code", "execution_count": 17, "id": "02c012f8-f91c-4a68-887a-ea0ef4cdcd76", "metadata": {}, "outputs": [], "source": [ "# Let us see an example of what input looks like for NAIP\n", "platform = \"naip\"\n", "metadata = dm.metadata[platform]\n", "wavelengths = list(metadata.bands.wavelength.values())\n", "gsd = metadata.gsd\n", "batch_naip = create_batch(naip, wavelengths, gsd, DEVICE)" ] }, { "cell_type": "code", "execution_count": 18, "id": "10f18655-a3eb-4298-a205-bf6a0d14535d", "metadata": {}, "outputs": [], "source": [ "with torch.no_grad():\n", " unmsk_patch, unmsk_idx, msk_idx, msk_matrix = module.model.encoder(batch_naip)\n", " pixels_naip, _ = module.model.decoder(\n", " unmsk_patch,\n", " unmsk_idx,\n", " msk_idx,\n", " msk_matrix,\n", " batch_naip[\"time\"],\n", " batch_naip[\"latlon\"],\n", " batch_naip[\"gsd\"],\n", " batch_naip[\"waves\"],\n", " )" ] }, { "cell_type": "code", "execution_count": 19, "id": "e6f31bae-6447-492c-bb1d-d105f52d02db", "metadata": {}, "outputs": [], "source": [ "def denormalize_images(normalized_images, means, stds):\n", " means = np.array(means)\n", " stds = np.array(stds)\n", " means = means.reshape(1, -1, 1, 1)\n", " stds = stds.reshape(1, -1, 1, 1)\n", " denormalized_images = normalized_images * stds + means\n", "\n", " return denormalized_images" ] }, { "cell_type": "code", "execution_count": 20, "id": "b3e19810-3f73-4db1-aca4-6bbae490ba83", "metadata": {}, "outputs": [], "source": [ "naip_mean = list(metadata.bands.mean.values())\n", "naip_std = list(metadata.bands.std.values())\n", "\n", "batch_naip_pixels = batch_naip[\"pixels\"].detach().cpu().numpy()\n", "batch_naip_pixels = denormalize_images(batch_naip_pixels, naip_mean, naip_std)\n", "batch_naip_pixels = batch_naip_pixels.astype(np.uint8)" ] }, { "cell_type": "markdown", "id": "2310413c-6498-4c03-826e-694886aec101", "metadata": {}, "source": [ "Rearrange the patches from the decoder back into pixel space.\n", "\n", "Output from Clay decoder is of shape: `batch x patches x pixel_values_per_patch`, we will reshape that to `batch x channel x height x width`" ] }, { "cell_type": "code", "execution_count": 21, "id": "0f97fd1b-d3d6-4137-bc41-b7209cd3b01c", "metadata": {}, "outputs": [], "source": [ "pixels_naip = rearrange(\n", " pixels_naip,\n", " \"b (h w) (c p1 p2) -> b c (h p1) (w p2)\",\n", " c=batch_naip_pixels.shape[1],\n", " h=28,\n", " w=28,\n", " p1=8,\n", " p2=8,\n", ")" ] }, { "cell_type": "code", "execution_count": 22, "id": "bc0b2c53-0e55-46e3-b83a-f847fe8a10f4", "metadata": {}, "outputs": [], "source": [ "denorm_pixels_naip = denormalize_images(\n", " pixels_naip.cpu().detach().numpy(), naip_mean, naip_std\n", ")\n", "denorm_pixels_naip = denorm_pixels_naip.astype(np.uint8)" ] }, { "cell_type": "code", "execution_count": null, "id": "32e734a3-bb9c-4026-a442-da65d3f8abe3", "metadata": {}, "outputs": [], "source": [ "fig, axs = plt.subplots(6, 8, figsize=(20, 14))\n", "\n", "for i in range(0, 6, 2):\n", " for j in range(8):\n", " idx = (i // 2) * 8 + j\n", " axs[i][j].imshow(batch_naip_pixels[idx, [0, 1, 2], ...].transpose(1, 2, 0))\n", " axs[i][j].set_axis_off()\n", " axs[i][j].set_title(f\"Actual {idx}\")\n", " axs[i + 1][j].imshow(denorm_pixels_naip[idx, [0, 1, 2], ...].transpose(1, 2, 0))\n", " axs[i + 1][j].set_axis_off()\n", " axs[i + 1][j].set_title(f\"Rec {idx}\")" ] }, { "cell_type": "markdown", "id": "e92899f3", "metadata": {}, "source": [ "![Minicube 
, { "cell_type": "markdown", "id": "e61236a8-8a1a-4a21-9780-b4e1e3c13938", "metadata": {}, "source": [ "### Next steps\n", "\n", "- Look at reconstructions from the other sensors\n", "- Interpolate between embeddings & reconstruct images from the interpolated points (create a video animation)" ] }, { "cell_type": "code", "execution_count": null, "id": "cbd54da5-72ae-4ac7-a234-01d874d1cd35", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12" } }, "nbformat": 4, "nbformat_minor": 5 }