{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "from PIL import Image\n",
        "from torchvision import transforms\n",
        "from torchvision import models\n",
        "from torch import nn\n",
        "from torch import optim\n",
        "import torch\n",
        "from tqdm import tqdm\n",
        "from torchvision.utils import save_image"
      ],
      "metadata": {
        "id": "mYwaMqzzRqvE"
      },
      "execution_count": null,
      "outputs": []
    },
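    {
      "cell_type": "code",
      "source": [
        "# Choose the compute device once so the cells below can share it.\n",
        "# Assumption: the Colab GPU runtime is active; otherwise this falls back to\n",
        "# CPU (7000 optimization steps on CPU will be very slow).\n",
        "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
        "print(device)"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },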
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yRymEu3xQy4s"
      },
      "outputs": [],
      "source": [
        "def read_images(img_path):\n",
        "  img = Image.open(img_path).convert('RGB')\n",
        "  transform_1 = transforms.Compose([transforms.Resize((512,512)),\n",
        "                                    transforms.ToTensor()])\n",
        "  return transform_1(img).to('cuda', torch.float)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "class StyleTransferModel(nn.Module):\n",
        "  def __init__(self):\n",
        "    super(StyleTransferModel, self).__init__()\n",
        "    self.model = models.vgg19(pretrained=True).features[:29]\n",
        "    self.random_layers = [0, 5, 10, 19, 28]\n",
        "\n",
        "  def forward(self, x):\n",
        "    features = []\n",
        "    for layer_idx, layer in enumerate(self.model):\n",
        "      x = layer(x)\n",
        "      if layer_idx in self.random_layers:\n",
        "        features.append(x)\n",
        "    return features"
      ],
      "metadata": {
        "id": "qA7l038XTTOc"
      },
      "execution_count": null,
      "outputs": []
    },
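    {
      "cell_type": "code",
      "source": [
        "# Optional sanity check (not part of the pipeline): confirm that indices\n",
        "# 0, 5, 10, 19 and 28 map to conv1_1, conv2_1, conv3_1, conv4_1 and conv5_1\n",
        "# in VGG-19, the layers used by Gatys et al. for style transfer.\n",
        "vgg_features = models.vgg19(weights=models.VGG19_Weights.IMAGENET1K_V1).features\n",
        "for idx in [0, 5, 10, 19, 28]:\n",
        "  print(idx, vgg_features[idx])"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },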
    {
      "cell_type": "code",
      "source": [
        "def content_loss(x, content_img):\n",
        "  return torch.mean((x-content_img)**2)"
      ],
      "metadata": {
        "id": "rAFMEF7fUwmL"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def style_loss(x, style_image):\n",
        "  c, h, w = x.shape\n",
        "  x_1 = x.reshape(c, h*w)\n",
        "  style_image_1 = style_image.reshape(c, h*w)\n",
        "\n",
        "  G_x = torch.mm(x_1, x_1.permute(1, 0))\n",
        "  G_style_image = torch.mm(style_image_1, style_image_1.permute(1, 0))\n",
        "\n",
        "  return torch.mean((G_x - G_style_image)**2)"
      ],
      "metadata": {
        "id": "9vmiZFO2VIvL"
      },
      "execution_count": null,
      "outputs": []
    },
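    {
      "cell_type": "code",
      "source": [
        "# Illustration only: the Gram matrix of a (C, H, W) feature map is C x C,\n",
        "# independent of spatial size; it captures channel-wise correlations, which\n",
        "# is what the style loss above compares.\n",
        "dummy = torch.randn(64, 32, 32)\n",
        "flat = dummy.reshape(64, -1)\n",
        "gram = torch.mm(flat, flat.t())\n",
        "print(gram.shape)  # torch.Size([64, 64])"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },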
    {
      "cell_type": "code",
      "source": [
        "alpha, beta = 8, 70\n",
        "def calculate_loss(x_f, content_f, style_f):\n",
        "  style_l = content_l = 0\n",
        "  for x, c, s in zip(x_f, content_f, style_f):\n",
        "    content_l += content_loss(x, c)\n",
        "    style_l += style_loss(x, s)\n",
        "\n",
        "  total_loss = alpha * content_l + beta * style_l\n",
        "  return total_loss"
      ],
      "metadata": {
        "id": "Yca7b1AoY5sO"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "style = read_images(\"/content/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg\")\n",
        "content = read_images(\"/content/Wolverine.png\")\n",
        "x = content.clone().requires_grad_(True)"
      ],
      "metadata": {
        "id": "vNxz9lHAVz_4"
      },
      "execution_count": null,
      "outputs": []
    },
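    {
      "cell_type": "code",
      "source": [
        "# Optional: preview the two inputs side by side before optimizing.\n",
        "# (Assumes matplotlib, which ships with Colab.)\n",
        "import matplotlib.pyplot as plt\n",
        "fig, axes = plt.subplots(1, 2, figsize=(10, 5))\n",
        "axes[0].imshow(content.permute(1, 2, 0).cpu())\n",
        "axes[0].set_title('content')\n",
        "axes[1].imshow(style.permute(1, 2, 0).cpu())\n",
        "axes[1].set_title('style')\n",
        "for ax in axes:\n",
        "  ax.axis('off')\n",
        "plt.show()"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    },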
    {
      "cell_type": "code",
      "source": [
        "optimizer = optim.Adam([x], lr=0.004)\n",
        "model = StyleTransferModel().eval().to('cuda')"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "-1uvpFLOWGMS",
        "outputId": "df6b5d84-e36f-46b7-c678-717f69957d34"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "/usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.\n",
            "  warnings.warn(\n",
            "/usr/local/lib/python3.11/dist-packages/torchvision/models/_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=VGG19_Weights.IMAGENET1K_V1`. You can also use `weights=VGG19_Weights.DEFAULT` to get the most up-to-date weights.\n",
            "  warnings.warn(msg)\n"
          ]
        }
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "progress_bar = tqdm(range(7000), desc=\"optimizing\")\n",
        "for i in progress_bar:\n",
        "  x_features = model(x)\n",
        "  style_features = model(style)\n",
        "  content_features = model(content)\n",
        "  total_loss = calculate_loss(x_features, content_features, style_features)\n",
        "\n",
        "  optimizer.zero_grad()\n",
        "  total_loss.backward()\n",
        "  optimizer.step()\n",
        "  progress_bar.set_postfix({\"total_loss\": f\"{total_loss.item()}\"})\n",
        "  if i % 100 == 0:\n",
        "    save_image(x, f'./output/wolvi_{i}.png')\n",
        "\n",
        "\n"
      ],
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "p6CLgdu4XDBK",
        "outputId": "cba92fcd-1168-49a0-a062-26804d5f9446"
      },
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "name": "stderr",
          "text": [
            "optimizing: 100%|██████████| 7000/7000 [42:01<00:00,  2.78it/s, total_loss=2417662.75]\n"
          ]
        }
      ]
    }
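    ,
    {
      "cell_type": "code",
      "source": [
        "# Save the final stylized image once the optimization loop has finished.\n",
        "# Detach from the graph and clamp to the valid [0, 1] pixel range first.\n",
        "save_image(x.detach().clamp(0, 1), './output/wolvi_final.png')"
      ],
      "metadata": {},
      "execution_count": null,
      "outputs": []
    }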
  ]
}