{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "<img src=\"pic.png\" width=\"1000px\">" ] }, { "cell_type": "code", "execution_count": 57, "metadata": {}, "outputs": [], "source": [ "import torch\n", "from torchvision import datasets, transforms\n", "from torch.utils.data import DataLoader\n", "import matplotlib.pyplot as plt\n", "import torchvision.transforms as transforms\n", "from sklearn.decomposition import PCA\n", "import torch.nn as nn\n", "from einops import rearrange\n", "from torchsummary import summary\n", "import torch.optim as optim\n", "import math\n", "from tqdm import tqdm\n", "import cv2\n", "import numpy as np\n", "from torch import Tensor\n", "\n", "\n", "\n", "import lightning as L" ] }, { "cell_type": "code", "execution_count": 58, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "cpu\n" ] } ], "source": [ "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n", "print(device)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Set global seed to compare same initialized weights\n", "torch.manual_seed(0)" ] }, { "cell_type": "code", "execution_count": 59, "metadata": {}, "outputs": [], "source": [ "class VQ_VAE(L.LightningModule):\n", " def __init__(self, input_dim, codebook_size, encoding_dim, hidden_dim, beta):\n", " super().__init__()\n", "\n", " assert len(input_dim) == 3, \"Input dimension must be 3D\"\n", " assert len(encoding_dim) == 3, \"Encoding dimension must be 3D\"\n", "\n", " # Set the dimensions\n", " self.input_dim = input_dim\n", " self.codebook_size = codebook_size\n", " self.encoding_dim = encoding_dim\n", " self.hidden_dim = hidden_dim\n", " self.beta = beta\n", "\n", " # Unwrap input dimension, extract the number of channels, height and width\n", " self._input_c = input_dim[0]\n", " self._input_h = input_dim[1]\n", " self._input_w = input_dim[2]\n", " self._input_dim_flat = self._input_c * self._input_h * self._input_w\n", "\n", " # Unwrap encoding dimension\n", " self._embedding_dim = encoding_dim[0]\n", " self._encoding_h = encoding_dim[1]\n", " self._encoding_w = encoding_dim[2]\n", "\n", " # Calculate flat encoding dimension\n", " self._encoding_dim_flat = self._embedding_dim * self._encoding_h * self._encoding_w\n", " \n", "\n", " # Setup encoder layers\n", " self.enc_fc1 = nn.Linear(self._input_dim_flat, self.hidden_dim)\n", " self.enc_fc2 = nn.Linear(self.hidden_dim, self._encoding_dim_flat)\n", " self.relu = nn.ReLU()\n", "\n", " # Setup decoder layers\n", " self.dec_fc1 = nn.Linear(self._encoding_dim_flat, self.hidden_dim)\n", " self.dec_fc2 = nn.Linear(self.hidden_dim, self._input_dim_flat)\n", " self.sigmoid = nn.Sigmoid()\n", "\n", " # Setup codebook embeddings\n", " self.codebook = nn.Embedding(self.codebook_size, self._embedding_dim)\n", " assert self.codebook.weight.requires_grad == True, \"Codebook should be learnable\"\n", " # TODO: Initialize codebook uniformly, like in the paper\n", " nn.init.uniform_(self.codebook.weight, -1/self.codebook_size, 1/self.codebook_size)\n", "\n", " # For now we initialize randomly the codebook vectors. We act like stupid, then if it works, we use uniform like in other implementations.\n", " \n", " def encode(self, x):\n", " \"\"\"\n", " x: Input image of shape (b, c, h, w)\n", " Output: z - encoded image of shape (b, h, w, c)\n", " \"\"\"\n", " assert x.shape[1:] == (self._input_c, self._input_h, self._input_w)\n", "\n", " # Run the layers\n", " x = rearrange(x, 'b c h w -> b (c h w)') # Flatten the input\n", " x = self.relu(self.enc_fc1(x))\n", " x = self.relu(self.enc_fc2(x))\n", "\n", " # Unflatten, last dimension is embedding dim\n", " x = rearrange(x, 'b (c h w) -> b h w c', c=self._embedding_dim, h=self._encoding_h, w=self._encoding_w)\n", "\n", " return x\n", " \n", " def decode(self, z_q):\n", " \"\"\" \n", " Input: z_q: Quantized encodings of shape (b h w c)\n", "\n", " Returns the reconstructed image x_hat of shape (b c h w)\n", " \"\"\"\n", " # Rearrange so we work on the values, also flatten because of fully-connected layers\n", " z_q = rearrange(z_q, 'b h w c -> b (c h w)')\n", "\n", " # Run the layers\n", " x = self.relu(self.dec_fc1(z_q))\n", " x = self.relu(self.dec_fc2(x))\n", " x_hat = self.sigmoid(x)\n", "\n", " # Rearrange back to image shape\n", " x_hat = rearrange(x_hat, 'b (c h w) -> b c h w', c=self._input_c, h=self._input_h, w=self._input_w)\n", "\n", " return x_hat\n", " \n", " def quantize(self, z):\n", " \"\"\"\n", " z_e: (batch_size, encoding_height, encoding_width, embedding_dim)\n", "\n", " Get closest (euclidean distance) codebook vector z_q given z.\n", "\n", " Returns quantized encodings z_q of shape (b h w c)\n", " \"\"\"\n", " assert z.shape[1:] == (self._encoding_h, self._encoding_w, self._embedding_dim)\n", "\n", " flat_input = rearrange(z, 'b h w c -> (b h w) c')\n", "\n", " # TODO: Replace this by my own distances calculation\n", " # Equation: ||z - e||^2 = ||z||^2 + ||e||^2 - 2 * <z,e>\n", " distances = (torch.sum(flat_input**2, dim=1, keepdim=True)\n", " + torch.sum(self.codebook.weight**2, dim=1)\n", " - 2 * torch.matmul(flat_input, self.codebook.weight.t()))\n", "\n", " #distances = self.calculate_distances(z)\n", "\n", " # Distances is shape (batch, height, width, embed_dim)\n", " # Get the index of the closest codebook vector\n", " min_indices = torch.argmin(distances, dim=1)\n", "\n", " # We now have matrix of one-hot vectors, with \"1\" at the index of the closest codebook vector, given by \"min_indices\"\n", " one_hot = nn.functional.one_hot(min_indices, num_classes=self.codebook_size)\n", "\n", " # Now to get the value of the closest codebook vector, we multiply the one-hot matrix with the codebook matrix\n", " z_q = torch.matmul(one_hot.float(), self.codebook.weight)\n", "\n", " # Reshape back\n", " z_q = rearrange(z_q, '(b h w) c -> b h w c', b=z.shape[0], h=self._encoding_h, w=self._encoding_w)\n", "\n", " return z_q\n", "\n", " def forward(self, x):\n", " assert x.shape[1:] == (self._input_c, self._input_h, self._input_w)\n", "\n", " # Encode\n", " z = self.encode(x)\n", "\n", " # Quantize the latent vector\n", " z_q = self.quantize(z)\n", "\n", " # Add streight through estimator\n", " z_q = z + (z_q - z).detach()\n", " \n", " # Decode\n", " x_reconstructed = self.decode(z_q)\n", "\n", " # The output image should have the same shape as the input image\n", " assert x_reconstructed.shape == x.shape\n", "\n", " # Return x hat (and also some other stuff for loss calculation and debugging)\n", " return x_reconstructed, z, z_q\n", "\n", " def training_step(self, batch, batch_idx):\n", " images, _ = batch\n", " \n", " # Forward pass\n", " x_hat, z_e, z_q = self(images)\n", "\n", " # Calculate loss\n", " # Reconstruction loss\n", " recon_loss = nn.BCELoss(reduction='sum')(x_hat, images)\n", " # recon_loss = nn.functional.mse_loss(x_hat, images)\n", "\n", " # Quantization loss\n", " quant_loss = nn.functional.mse_loss(z_e.detach(), z_q) # TODO: This loss is increasing\n", " # quant_loss = torch.mean(torch.norm((emb - z_e.detach())**2, 2, 1))\n", "\n", " # Commitment loss\n", " commit_loss = nn.functional.mse_loss(z_q.detach(), z_e)\n", "\n", " # Total loss\n", " loss = recon_loss + quant_loss + self.beta * commit_loss\n", "\n", " values = {\"loss\": loss, \"recon_loss\": recon_loss, \"quant_loss\": quant_loss, \"commit_loss\": commit_loss}\n", " self.log_dict(values)\n", "\n", " return loss\n", " \n", " def configure_optimizers(self):\n", " lr = 1e-3\n", " optimizer = optim.Adam(self.parameters(), lr=lr)\n", " return optimizer\n" ] }, { "cell_type": "code", "execution_count": 60, "metadata": {}, "outputs": [], "source": [ "# Model\n", "input_dim = (3, 32, 32)\n", "codebook_size = 512\n", "\n", "embedding_dim = 64 # Dimension of each codebook vector\n", "encoding_dim = (embedding_dim, 8, 8)\n", "\n", "hidden_dim = 1024\n", "\n", "beta = 0.25" ] }, { "cell_type": "code", "execution_count": 61, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "----------------------------------------------------------------\n", " Layer (type) Output Shape Param #\n", "================================================================\n", " Linear-1 [-1, 1024] 3,146,752\n", " ReLU-2 [-1, 1024] 0\n", " Linear-3 [-1, 4096] 4,198,400\n", " ReLU-4 [-1, 4096] 0\n", " Linear-5 [-1, 1024] 4,195,328\n", " ReLU-6 [-1, 1024] 0\n", " Linear-7 [-1, 3072] 3,148,800\n", " ReLU-8 [-1, 3072] 0\n", " Sigmoid-9 [-1, 3072] 0\n", "================================================================\n", "Total params: 14,689,280\n", "Trainable params: 14,689,280\n", "Non-trainable params: 0\n", "----------------------------------------------------------------\n", "Input size (MB): 0.01\n", "Forward/backward pass size (MB): 0.16\n", "Params size (MB): 56.04\n", "Estimated Total Size (MB): 56.21\n", "----------------------------------------------------------------\n" ] } ], "source": [ "model = VQ_VAE(input_dim, codebook_size, encoding_dim, hidden_dim, beta)\n", "model.to(device)\n", "summary(model, (3, 32, 32))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Quantization of z_e\n", "\n", "To quantisize the z_e we run `torch.norm()` which calculate length of vector.\n", "\n", "For instance if `a = torch.tensor([1, 1])` then the length of vector is `sqrt(1^2 + 1^2) = sqrt(2) = 1.4142`\n", "\n", "So we do: `torch.norm(z_e - codebook)` which means we measure distance between two vectors.\n", "\n", "Then we apply `argmin`: `torch.argmin(torch.norm(z_e - codebook))` to get the index of the closest vector in the codebook.\n", "\n", "Finally we get the quantized vector: `codebook[torch.argmin(torch.norm(z_e - codebook))]`" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Define dataset and dataloader" ] }, { "cell_type": "code", "execution_count": 62, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Files already downloaded and verified\n", "Files already downloaded and verified\n" ] } ], "source": [ "# Transformations\n", "# Note: normalizing the images to have values in the range [0, 1] is important for the calculation of BCE loss\n", "transform = transforms.Compose([\n", " transforms.ToTensor(),\n", " lambda x: (x - x.min()) / (x.max() - x.min()) # Normalize to [0, 1]\n", "])\n", "\n", "# Load dataset\n", "# CIFAR-10\n", "train_dataset = datasets.CIFAR10(root=\"./data\", train=True, download=True, transform=transform)\n", "test_dataset = datasets.CIFAR10(root=\"./data\", train=False, download=True, transform=transform)\n", "\n", "# Data loaders\n", "test_loader_batch = 32\n", "train_loader = DataLoader(train_dataset, shuffle=True)\n", "test_loader = DataLoader(test_dataset, batch_size=test_loader_batch, shuffle=False)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Train" ] }, { "cell_type": "code", "execution_count": 63, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "GPU available: True (mps), used: True\n", "TPU available: False, using: 0 TPU cores\n", "HPU available: False, using: 0 HPUs\n", "\n", " | Name | Type | Params | Mode \n", "-----------------------------------------------\n", "0 | enc_fc1 | Linear | 3.1 M | train\n", "1 | enc_fc2 | Linear | 4.2 M | train\n", "2 | relu | ReLU | 0 | train\n", "3 | dec_fc1 | Linear | 4.2 M | train\n", "4 | dec_fc2 | Linear | 3.1 M | train\n", "5 | sigmoid | Sigmoid | 0 | train\n", "6 | codebook | Embedding | 32.8 K | train\n", "-----------------------------------------------\n", "14.7 M Trainable params\n", "0 Non-trainable params\n", "14.7 M Total params\n", "58.888 Total estimated model params size (MB)\n", "7 Modules in train mode\n", "0 Modules in eval mode\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4: 100%|██████████| 256/256 [00:08<00:00, 28.57it/s, v_num=22]" ] }, { "name": "stderr", "output_type": "stream", "text": [ "`Trainer.fit` stopped: `max_epochs=5` reached.\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ "Epoch 4: 100%|██████████| 256/256 [00:09<00:00, 27.86it/s, v_num=22]\n" ] } ], "source": [ "trainer = L.Trainer(limit_train_batches=256, max_epochs=5)\n", "trainer.fit(model=model, train_dataloaders=train_loader)" ] }, { "cell_type": "code", "execution_count": 64, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "torch.Size([32, 3, 32, 32])\n", "torch.Size([8, 3, 32, 32])\n" ] }, { "data": { "image/png": "", "text/plain": [ "<Figure size 2000x500 with 16 Axes>" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "# Plot 8 images (original image) and to their right plot the reconstructed image\n", "model.eval()\n", "with torch.no_grad():\n", " images, labels = next(iter(test_loader))\n", " assert len(images) >= 8, \"Test loader must have at least 8 images, got only {}\".format(len(images))\n", " print(images.shape)\n", " images = images[:8].to(device)\n", "\n", " outputs, _, _ = model(images)\n", "\n", " images = images.cpu()\n", " outputs = outputs.cpu()\n", "\n", " print(outputs.shape)\n", "\n", " fig, axes = plt.subplots(2, 8, figsize=(20, 5))\n", " for i in range(8):\n", " # Change the order of the channels to (h, w, c) for matplotlib\n", " axes[0, i].imshow(images[i].permute(1, 2, 0))\n", " axes[0, i].axis('off')\n", " # Add title for the image\n", " axes[0, i].set_title(f\"{train_dataset.classes[labels[i]]}\")\n", " axes[1, i].imshow(outputs[i].permute(1, 2, 0))\n", " axes[1, i].axis('off')\n", "\n", " plt.show()" ] } ], "metadata": { "kernelspec": { "display_name": "base", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.2" } }, "nbformat": 4, "nbformat_minor": 2 }