{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Train autoencoder for embedding extraction\n", "\n", "This notebook showcases more ways of working with metrics and embeddings in 3LC. It is mostly meant as a demonstration of how to collect embeddings and image metrics from a manually trained model. \n", "\n", "![](../images/autoencoder.png)\n", "\n", "\n", "\n", "The auto-encoder architecture is mainly used as a simple example to demonstrate the process, and the model should only be considered as an example of an embedding extractor, which also produces images as a side effect." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Install dependencies" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%pip install 3lc[pacmap]\n", "%pip install git+https://github.com/3lc-ai/3lc-examples.git\n", "%pip install timm" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import tlc\n", "import torch\n", "import torch.nn as nn\n", "from timm import create_model\n", "from torch.utils.data import DataLoader\n", "from torchvision import transforms\n", "from tqdm.auto import tqdm\n", "\n", "from tlc_tools.common import infer_torch_device" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Project setup" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "parameters" ] }, "outputs": [], "source": [ "DOWNLOAD_PATH = \"../../transient_data\"\n", "PROJECT_NAME = \"3LC Tutorials - CIFAR-10\"\n", "RUN_NAME = \"Train autoencoder\"\n", "RUN_DESCRIPTION = \"Train an autoencoder and collect embeddings and reconstructions\"\n", "BACKBONE = \"resnet50\"\n", "EMBEDDING_DIM = 512 # Desired embedding dimension\n", "EPOCHS = 10\n", "FREEZE_ENCODER = False\n", "IMAGE_WIDTH = 32\n", "IMAGE_HEIGHT = 32\n", "NUM_CHANNELS = 3\n", "BATCH_SIZE = 64\n", "METHOD = \"pacmap\"\n", "NUM_COMPONENTS = 2\n", "NUM_WORKERS = 0" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "CHECKPOINT_PATH = DOWNLOAD_PATH + \"/autoencoder_model.pth\"" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Load input Table" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "train_table = tlc.Table.from_names(\"initial\", \"CIFAR-10-train\", \"3LC Tutorials - CIFAR-10\")\n", "val_table = tlc.Table.from_names(\"initial\", \"CIFAR-10-val\", \"3LC Tutorials - CIFAR-10\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Prepare Data\n", "transform = transforms.Compose(\n", " [\n", " transforms.ToTensor(),\n", " ]\n", ")\n", "\n", "\n", "def map_fn(sample):\n", " \"\"\"Map samples from the table to be compatible with the model.\"\"\"\n", " image = sample[0]\n", " image = transform(image)\n", " return image\n", "\n", "\n", "train_table.clear_maps()\n", "train_table.map(map_fn)\n", "\n", "val_table.clear_maps()\n", "val_table.map(map_fn)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train autoencoder" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "class Autoencoder(nn.Module):\n", " def __init__(self, backbone_name=\"resnet50\", embedding_dim=512, freeze_encoder=FREEZE_ENCODER):\n", " super().__init__()\n", "\n", " # Load the backbone as an encoder\n", " self.encoder = create_model(backbone_name, pretrained=True, num_classes=0)\n", " encoder_output_dim = self.encoder.feature_info[-1][\"num_chs\"]\n", "\n", " # Freeze encoder parameters if specified\n", " if freeze_encoder:\n", " for param in self.encoder.parameters():\n", " param.requires_grad = False\n", "\n", " # Add a projection layer to reduce to embedding_dim\n", " self.projector = nn.Linear(encoder_output_dim, embedding_dim)\n", "\n", " # Define the decoder\n", " self.decoder = nn.Sequential(\n", " nn.Linear(embedding_dim, encoder_output_dim),\n", " nn.ReLU(),\n", " nn.Linear(encoder_output_dim, IMAGE_HEIGHT * IMAGE_WIDTH * NUM_CHANNELS),\n", " nn.Sigmoid(),\n", " )\n", "\n", " def forward(self, x):\n", " # Encoder\n", " features = self.encoder(x)\n", " embeddings = self.projector(features)\n", "\n", " # Decoder\n", " reconstructions = self.decoder(embeddings)\n", " reconstructions = reconstructions.view(x.size(0), NUM_CHANNELS, IMAGE_WIDTH, IMAGE_HEIGHT)\n", " return embeddings, reconstructions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Initialize the model\n", "model = Autoencoder(backbone_name=BACKBONE, embedding_dim=EMBEDDING_DIM)\n", "\n", "# Training Components\n", "criterion = nn.MSELoss() # Reconstruction loss\n", "optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)\n", "\n", "# Create data loaders\n", "train_loader = DataLoader(train_table, batch_size=BATCH_SIZE, shuffle=True)\n", "val_loader = DataLoader(val_table, batch_size=BATCH_SIZE, shuffle=False)\n", "\n", "device = infer_torch_device()\n", "model.to(device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Training loop\n", "for epoch in range(EPOCHS):\n", " model.train()\n", " epoch_train_loss = 0.0\n", " epoch_val_loss = 0.0\n", "\n", " for images in tqdm(train_loader, desc=\"Training\", total=len(train_loader)):\n", " images = images.to(device)\n", "\n", " # Forward pass\n", " embeddings, reconstructions = model(images)\n", "\n", " # Compute loss\n", " loss = criterion(reconstructions, images)\n", "\n", " # Backward pass and optimization\n", " optimizer.zero_grad()\n", " loss.backward()\n", " optimizer.step()\n", "\n", " epoch_train_loss += loss.item()\n", "\n", " # Validation pass\n", " model.eval()\n", " with torch.no_grad():\n", " for images in tqdm(val_loader, desc=\"Validation\", total=len(val_loader)):\n", " images = images.to(device)\n", "\n", " # Forward pass\n", " embeddings, reconstructions = model(images)\n", "\n", " # Compute loss\n", " loss = criterion(reconstructions, images)\n", "\n", " epoch_val_loss += loss.item()\n", "\n", " epoch_train_loss /= len(train_loader)\n", " epoch_val_loss /= len(val_loader)\n", "\n", " print(f\"Epoch [{epoch + 1}/{EPOCHS}], Train Loss: {epoch_train_loss:.4f}, Val Loss: {epoch_val_loss:.4f}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Save the model\n", "torch.save(model.state_dict(), CHECKPOINT_PATH)\n", "print(f\"Model saved to {CHECKPOINT_PATH}\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Collect metrics from the trained model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "unreduced_loss = nn.MSELoss(reduction=\"none\") # Reconstruction loss\n", "\n", "\n", "def metrics_fn(batch, predictor_output):\n", " embeddings, reconstructions = predictor_output.forward\n", " reconstructed_images = [transforms.ToPILImage()(image.cpu()) for image in reconstructions]\n", " reconstruction_loss = unreduced_loss(reconstructions.to(device), batch.to(device)).mean(dim=(1, 2, 3))\n", " return {\n", " \"embeddings\": embeddings.cpu().detach().numpy(),\n", " \"reconstructions\": reconstructed_images,\n", " \"reconstruction_loss\": reconstruction_loss.cpu().detach().numpy(),\n", " }" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "run = tlc.init(project_name=PROJECT_NAME, run_name=RUN_NAME, description=RUN_DESCRIPTION)\n", "\n", "tlc.collect_metrics(\n", " train_table,\n", " metrics_fn,\n", " model,\n", " collect_aggregates=False,\n", " dataloader_args={\"batch_size\": BATCH_SIZE, \"num_workers\": NUM_WORKERS},\n", ")\n", "\n", "tlc.collect_metrics(\n", " val_table,\n", " metrics_fn,\n", " model,\n", " collect_aggregates=False,\n", " dataloader_args={\"batch_size\": BATCH_SIZE, \"num_workers\": NUM_WORKERS},\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Reduce embeddings to 2D" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "run.reduce_embeddings_by_foreign_table_url(\n", " train_table.url,\n", " source_embedding_column=\"embeddings\",\n", " method=METHOD,\n", " n_components=NUM_COMPONENTS,\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.9" }, "test_marks": [ "slow" ] }, "nbformat": 4, "nbformat_minor": 2 }