# %% [markdown]
# # CLAY v0 - Interpolation between images

# %%
import sys

# Make the repository root importable so the `src` package below resolves.
sys.path.append("../")

# %%
import os
from pathlib import Path

import imageio
import matplotlib.pyplot as plt
import numpy as np
import torch
from einops import rearrange
from PIL import Image

from src.datamodule import ClayDataModule, ClayDataset
from src.model_clay import CLAYModule

# %%
# data directory for all chips
DATA_DIR = "../data/02"
# path of best model checkpoint for Clay v0
CKPT_PATH = "https://huggingface.co/made-with-clay/Clay/resolve/main/Clay_v0.1_epoch-24_val-loss-0.46.ckpt"

# %% [markdown]
# ## Load Model & DataModule

# %%
# Load the model & set in eval mode
model = CLAYModule.load_from_checkpoint(
    CKPT_PATH, mask_ratio=0.0, shuffle=False
)  # No masking or shuffling of patches
model.eval();

# %%
data_dir = Path(DATA_DIR)

# Load the Clay DataModule
ds = ClayDataset(chips_path=list(data_dir.glob("**/*.tif")))
# batch_size=2 matters: later cells unpack each batch into exactly two chips.
dm = ClayDataModule(data_dir=str(data_dir), batch_size=2)
dm.setup(stage="fit")

# Load the train DataLoader
trn_dl = iter(dm.train_dataloader())

# %%
# Load the first batch of chips
batch = next(trn_dl)
batch.keys()

# %%
batch["pixels"].shape, batch["latlon"].shape, batch["timestep"].shape


# %%
def show(sample, idx=None, save=False):
    """Draw the RGB composite of a single normalized chip.

    The chip is de-normalized with the datamodule statistics (`dm.STD`,
    `dm.MEAN`), three channels are picked as RGB, and the result is min-max
    scaled to [0, 1] before being drawn on a fresh figure.

    Parameters
    ----------
    sample : torch.Tensor
        One chip laid out as (channels, height, width). The callers in this
        notebook pass CPU tensors.
    idx : int, optional
        Frame number used in the output file name when `save` is true.
    save : bool, default False
        When true, the figure is written to `animate/chip_{idx}.png` and then
        closed, so repeated calls in a loop neither stack images on a shared
        axes nor pile up open figures. When false, the figure is left open for
        the notebook to display.

    Returns
    -------
    None
    """
    chip_hwc = rearrange(sample, "c h w -> h w c")
    # Relies on the notebook-level `dm`; STD/MEAN broadcast over the channel axis.
    denorm_sample = chip_hwc * torch.as_tensor(dm.STD) + torch.as_tensor(dm.MEAN)
    # Channels 2, 1, 0 are displayed as R, G, B.
    # NOTE(review): presumably the red/green/blue bands; confirm against the band order.
    rgb = denorm_sample[..., [2, 1, 0]]

    # Min-max scale for display; a constant-valued chip would otherwise divide by zero.
    lowest = rgb.min()
    value_range = rgb.max() - lowest
    if value_range > 0:
        scaled = (rgb - lowest) / value_range
    else:
        scaled = torch.zeros_like(rgb)

    # Explicit figure/axes instead of the implicit pyplot "current figure".
    fig, ax = plt.subplots()
    ax.imshow(scaled)
    ax.axis("off")

    if save:
        # Only touch the filesystem when a file is actually requested.
        Path("animate").mkdir(exist_ok=True)
        fig.savefig(f"animate/chip_{idx}.png")
        plt.close(fig)


# %%
sample1, sample2 = batch["pixels"]

# %%
show(sample1)

# %%
show(sample2)

# %% [markdown]
# Each batch has chips of shape `13 x 512 x 512`, normalized `lat` & `lon` coords & normalized timestep information as `year`, `month` & `day`.
# %%
# Save a copy of batch to visualize later
_batch = batch["pixels"].detach().clone().cpu().numpy()

# %% [markdown]
# ## Pass data through the CLAY model

# %%
# Pass the pixels through the encoder & decoder of CLAY
with torch.no_grad():
    # Move the data to the device of the model
    batch["pixels"] = batch["pixels"].to(model.device)
    batch["timestep"] = batch["timestep"].to(model.device)
    batch["latlon"] = batch["latlon"].to(model.device)

    # Pass pixels, latlon, timestep through the encoder to create encoded patches
    (
        unmasked_patches,
        unmasked_indices,
        masked_indices,
        masked_matrix,
    ) = model.model.encoder(batch)

# %% [markdown]
# ### Create an image based on interpolation of the embedding values between 2 images
# *Images are saved inside `./animate`*

# %%
# Number of interpolation steps. The frame loop, the preview grid and the GIF
# writer below all read this single value so they cannot drift apart.
N_FRAMES = 20

# Encoded patches of the two chips in the batch; constant across frames.
l1, l2 = unmasked_patches

for idx, alpha in enumerate(np.linspace(0, 1, N_FRAMES)):
    # Linear blend of the two embeddings: alpha = 0 is purely the second chip,
    # alpha = 1 is purely the first chip.
    l3 = alpha * l1 + (1 - alpha) * l2

    # Pass the blended patches through the decoder to reconstruct the pixel space
    with torch.no_grad():
        pixels = model.model.decoder(
            rearrange(l3, "gl d -> 1 gl d"), unmasked_indices[[0]], masked_indices[[0]]
        )

    # Stitch the patch grid back into full images: 16 patches of 32 px per side.
    image = rearrange(pixels, "b c (h w) (p1 p2) -> b c (h p1) (w p2)", h=16, p1=32)
    _image = image[0].detach().cpu()
    show(_image, idx, save=True)

# %%
# Preview grid with 10 frames per row (2 x 10 for the default 20 frames).
GRID_COLS = 10
grid_rows = -(-N_FRAMES // GRID_COLS)  # ceiling division
fig, axs = plt.subplots(
    grid_rows, GRID_COLS, figsize=(2 * GRID_COLS, 2 * grid_rows), squeeze=False
)
for ax in axs.flatten():
    # Hide every axes up front so unused cells stay blank when N_FRAMES is not a multiple of GRID_COLS.
    ax.set_axis_off()
for ax, idx in zip(axs.flatten(), range(N_FRAMES)):
    ax.imshow(Image.open(f"./animate/chip_{idx}.png"))
    ax.set_title(f"Seq {idx}")
plt.tight_layout()

# %% [markdown]
# #### Create a GIF of the interpolation of images

# %%
img_paths = [f"./animate/chip_{idx}.png" for idx in range(N_FRAMES)]

# NOTE(review): the unit of `duration` depends on the installed imageio version/plugin; confirm.
with imageio.get_writer("animate/sample.gif", mode="I", duration=100) as writer:
    for img_path in img_paths:
        img = imageio.imread(img_path)
        writer.append_data(img)

# Delete the images
for img_path in img_paths:
    os.remove(img_path)

# %%
# Aliased on purpose: importing IPython's `Image` under its own name would
# rebind the notebook-level `Image` (PIL) used by the preview grid above and
# break that cell when it is re-run.
from IPython.display import Image as IPyImage, display

display(IPyImage(filename="./animate/sample.gif"))