{
"cells": [
{
"cell_type": "markdown",
"id": "3a609692",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"# Neural Style Transfer\n",
"\n",
"Reading the Content and Style Images"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "e0868e28",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:52:55.964795Z",
"iopub.status.busy": "2023-08-18T19:52:55.964503Z",
"iopub.status.idle": "2023-08-18T19:52:59.646946Z",
"shell.execute_reply": "2023-08-18T19:52:59.645534Z"
},
"origin_pos": 2,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"%matplotlib inline\n",
"import torch\n",
"import torchvision\n",
"from torch import nn\n",
"from d2l import torch as d2l\n",
"\n",
"d2l.set_figsize()\n",
"\n",
"# Content image: the content target and also the starting point of the\n",
"# synthesized image (see get_contents and the final training cell).\n",
"content_img = d2l.Image.open('../img/rainier.jpg')\n",
"d2l.plt.imshow(content_img);"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "283f5e51",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:52:59.651637Z",
"iopub.status.busy": "2023-08-18T19:52:59.650615Z",
"iopub.status.idle": "2023-08-18T19:53:00.264518Z",
"shell.execute_reply": "2023-08-18T19:53:00.263173Z"
},
"origin_pos": 4,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Style image: its VGG features become the style targets (see get_styles).\n",
"style_img = d2l.Image.open('../img/autumn-oak.jpg')\n",
"d2l.plt.imshow(style_img);"
]
},
{
"cell_type": "markdown",
"id": "38dcedd0",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Preprocessing and Postprocessing"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "9f1ef9cd",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:53:00.269093Z",
"iopub.status.busy": "2023-08-18T19:53:00.268290Z",
"iopub.status.idle": "2023-08-18T19:53:00.275592Z",
"shell.execute_reply": "2023-08-18T19:53:00.274696Z"
},
"origin_pos": 7,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"rgb_mean = torch.tensor([0.485, 0.456, 0.406])\n",
"rgb_std = torch.tensor([0.229, 0.224, 0.225])\n",
"\n",
"def preprocess(img, image_shape):\n",
"    \"\"\"Turn an image into a normalized (1, 3, H, W) tensor for VGG.\n",
"\n",
"    The image is resized to `image_shape`, converted to a tensor, and\n",
"    normalized per channel with `rgb_mean` / `rgb_std`.\n",
"    \"\"\"\n",
"    # Normalize needs exactly 3 channels, so convert grayscale/RGBA/palette\n",
"    # PIL images to RGB; inputs without a PIL `mode` are passed through.\n",
"    if getattr(img, 'mode', 'RGB') != 'RGB':\n",
"        img = img.convert('RGB')\n",
"    transforms = torchvision.transforms.Compose([\n",
"        torchvision.transforms.Resize(image_shape),\n",
"        torchvision.transforms.ToTensor(),\n",
"        torchvision.transforms.Normalize(mean=rgb_mean, std=rgb_std)])\n",
"    return transforms(img).unsqueeze(0)\n",
"\n",
"def postprocess(img):\n",
"    \"\"\"Invert `preprocess` for the first image in a batch; return a PIL image.\n",
"\n",
"    The tensor is detached, moved to the device of the normalization\n",
"    constants, de-normalized, and clamped to [0, 1].\n",
"    \"\"\"\n",
"    # detach(): this is called during training for display only, so the\n",
"    # ops below should not be recorded in the autograd graph.\n",
"    img = img[0].detach().to(rgb_std.device)\n",
"    img = torch.clamp(img.permute(1, 2, 0) * rgb_std + rgb_mean, 0, 1)\n",
"    return torchvision.transforms.ToPILImage()(img.permute(2, 0, 1))"
]
},
{
"cell_type": "markdown",
"id": "c04a6f06",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Extracting Features"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "1e1a4a43",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:53:00.278914Z",
"iopub.status.busy": "2023-08-18T19:53:00.278636Z",
"iopub.status.idle": "2023-08-18T19:53:04.940646Z",
"shell.execute_reply": "2023-08-18T19:53:04.939402Z"
},
"origin_pos": 10,
"tab": [
"pytorch"
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Downloading: \"https://download.pytorch.org/models/vgg19-dcbb9e9d.pth\" to /home/ci/.cache/torch/hub/checkpoints/vgg19-dcbb9e9d.pth\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 0%| | 0.00/548M [00:00, ?B/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 0%| | 2.07M/548M [00:00<00:26, 21.7MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 1%| | 5.80M/548M [00:00<00:17, 31.9MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 2%|▏ | 13.7M/548M [00:00<00:10, 54.6MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 5%|▌ | 28.1M/548M [00:00<00:05, 92.4MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 7%|▋ | 40.4M/548M [00:00<00:05, 106MB/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 11%|█ | 61.2M/548M [00:00<00:03, 144MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 16%|█▌ | 85.1M/548M [00:00<00:02, 179MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 20%|█▉ | 110M/548M [00:00<00:02, 203MB/s] "
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 24%|██▎ | 129M/548M [00:00<00:02, 205MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 28%|██▊ | 153M/548M [00:01<00:01, 218MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 32%|███▏ | 177M/548M [00:01<00:01, 230MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 36%|███▋ | 199M/548M [00:01<00:02, 174MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 40%|████ | 221M/548M [00:01<00:01, 188MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 45%|████▍ | 245M/548M [00:01<00:01, 205MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 49%|████▉ | 269M/548M [00:01<00:01, 217MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 53%|█████▎ | 293M/548M [00:01<00:01, 225MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 58%|█████▊ | 316M/548M [00:01<00:01, 232MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 62%|██████▏ | 340M/548M [00:01<00:00, 237MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 66%|██████▋ | 364M/548M [00:02<00:00, 240MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 71%|███████ | 389M/548M [00:02<00:00, 247MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 76%|███████▌ | 415M/548M [00:02<00:00, 256MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 81%|████████ | 442M/548M [00:02<00:00, 264MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 86%|████████▌ | 469M/548M [00:02<00:00, 270MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 91%|█████████ | 496M/548M [00:02<00:00, 273MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
" 95%|█████████▌| 523M/548M [00:02<00:00, 276MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\r",
"100%|██████████| 548M/548M [00:02<00:00, 213MB/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n"
]
}
],
"source": [
"# `pretrained=True` is deprecated since torchvision 0.13; use the weights enum\n",
"# (IMAGENET1K_V1 is the same vgg19-dcbb9e9d checkpoint) and fall back to the\n",
"# legacy flag on older torchvision releases that lack it.\n",
"if hasattr(torchvision.models, 'VGG19_Weights'):\n",
"    pretrained_net = torchvision.models.vgg19(\n",
"        weights=torchvision.models.VGG19_Weights.IMAGENET1K_V1)\n",
"else:\n",
"    pretrained_net = torchvision.models.vgg19(pretrained=True)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "1caba7d6",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:53:04.970990Z",
"iopub.status.busy": "2023-08-18T19:53:04.970462Z",
"iopub.status.idle": "2023-08-18T19:53:04.975342Z",
"shell.execute_reply": "2023-08-18T19:53:04.974495Z"
},
"origin_pos": 20,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"style_layers, content_layers = [0, 5, 10, 19, 28], [25]\n",
"\n",
"# Keep only the VGG feature layers up to the deepest one we read from.\n",
"net = nn.Sequential(*[pretrained_net.features[i] for i in\n",
"                      range(max(content_layers + style_layers) + 1)])\n",
"# VGG is a fixed feature extractor: only the synthesized image is optimized.\n",
"# Without freezing, every backward() also computes gradients for the VGG\n",
"# weights, and they accumulate because the optimizer never zeroes them.\n",
"net.requires_grad_(False)\n",
"\n",
"def extract_features(X, content_layers, style_layers):\n",
"    \"\"\"Run X through `net`, collecting the outputs of selected layers.\n",
"\n",
"    Returns (contents, styles): the outputs of the layers whose indices are\n",
"    in `content_layers` and `style_layers`, respectively, in layer order.\n",
"    \"\"\"\n",
"    contents = []\n",
"    styles = []\n",
"    for i, layer in enumerate(net):\n",
"        X = layer(X)\n",
"        # NOTE(review): torchvision's VGG ReLUs are in-place, so an output\n",
"        # stored here may later be overwritten by the next ReLU in `net`;\n",
"        # confirm whether pre- or post-activation features are intended.\n",
"        if i in style_layers:\n",
"            styles.append(X)\n",
"        if i in content_layers:\n",
"            contents.append(X)\n",
"    return contents, styles\n",
"\n",
"def get_contents(image_shape, device):\n",
"    \"\"\"Preprocess `content_img`; return it and its content-layer features.\"\"\"\n",
"    content_X = preprocess(content_img, image_shape).to(device)\n",
"    contents_Y, _ = extract_features(content_X, content_layers, style_layers)\n",
"    return content_X, contents_Y\n",
"\n",
"def get_styles(image_shape, device):\n",
"    \"\"\"Preprocess `style_img`; return it and its style-layer features.\"\"\"\n",
"    style_X = preprocess(style_img, image_shape).to(device)\n",
"    _, styles_Y = extract_features(style_X, content_layers, style_layers)\n",
"    return style_X, styles_Y"
]
},
{
"cell_type": "markdown",
"id": "fcbb50a4",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Defining the Loss Function"
]
},
{
"cell_type": "code",
"execution_count": 12,
"id": "761c6242",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:53:05.004270Z",
"iopub.status.busy": "2023-08-18T19:53:05.003997Z",
"iopub.status.idle": "2023-08-18T19:53:05.008815Z",
"shell.execute_reply": "2023-08-18T19:53:05.008035Z"
},
"origin_pos": 30,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def content_loss(Y_hat, Y):\n",
"    \"\"\"Mean squared error between synthesized and target content features.\"\"\"\n",
"    # The target is a constant: detach it so no gradient flows into it.\n",
"    diff = Y_hat - Y.detach()\n",
"    return torch.square(diff).mean()\n",
"\n",
"def gram(X):\n",
"    \"\"\"Gram matrix of X's channels, normalized by channels * positions.\n",
"\n",
"    X is flattened to (channels, numel // channels), which is only the\n",
"    per-image Gram matrix when the batch size is 1 (as in this notebook).\n",
"    \"\"\"\n",
"    num_channels = X.shape[1]\n",
"    n = X.numel() // num_channels\n",
"    features = X.reshape((num_channels, n))\n",
"    return features @ features.T / (num_channels * n)\n",
"\n",
"def style_loss(Y_hat, gram_Y):\n",
"    \"\"\"Mean squared error between gram(Y_hat) and a target Gram matrix.\"\"\"\n",
"    diff = gram(Y_hat) - gram_Y.detach()\n",
"    return torch.square(diff).mean()\n",
"\n",
"def tv_loss(Y_hat):\n",
"    \"\"\"Total variation: mean absolute difference between neighboring pixels.\"\"\"\n",
"    vertical = torch.abs(Y_hat[:, :, 1:, :] - Y_hat[:, :, :-1, :]).mean()\n",
"    horizontal = torch.abs(Y_hat[:, :, :, 1:] - Y_hat[:, :, :, :-1]).mean()\n",
"    return 0.5 * (vertical + horizontal)"
]
},
{
"cell_type": "markdown",
"id": "b799ff1e",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"The loss function of style transfer is the weighted sum of the content loss, the style loss, and the total variation (TV) loss."
]
},
{
"cell_type": "code",
"execution_count": 13,
"id": "fa523eb8",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:53:05.012308Z",
"iopub.status.busy": "2023-08-18T19:53:05.011878Z",
"iopub.status.idle": "2023-08-18T19:53:05.019144Z",
"shell.execute_reply": "2023-08-18T19:53:05.018338Z"
},
"origin_pos": 32,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"content_weight, style_weight, tv_weight = 1, 1e4, 10\n",
"\n",
"def compute_loss(X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram):\n",
"    \"\"\"Weighted content, style, and total-variation losses for image X.\n",
"\n",
"    Returns (contents_l, styles_l, tv_l, l): per-layer weighted content\n",
"    losses, per-layer weighted style losses, the weighted TV loss, and the\n",
"    total of all of them.\n",
"    \"\"\"\n",
"    contents_l = []\n",
"    for Y_hat, Y in zip(contents_Y_hat, contents_Y):\n",
"        contents_l.append(content_loss(Y_hat, Y) * content_weight)\n",
"    styles_l = []\n",
"    for Y_hat, gram_Y in zip(styles_Y_hat, styles_Y_gram):\n",
"        styles_l.append(style_loss(Y_hat, gram_Y) * style_weight)\n",
"    tv_l = tv_loss(X) * tv_weight\n",
"    # Summation order: style terms, then content terms, then TV.\n",
"    l = sum(styles_l + contents_l + [tv_l])\n",
"    return contents_l, styles_l, tv_l, l"
]
},
{
"cell_type": "markdown",
"id": "65fef189",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Initializing the Synthesized Image"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "4141c09a",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:53:05.030294Z",
"iopub.status.busy": "2023-08-18T19:53:05.029635Z",
"iopub.status.idle": "2023-08-18T19:53:05.034918Z",
"shell.execute_reply": "2023-08-18T19:53:05.033861Z"
},
"origin_pos": 38,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"class SynthesizedImage(nn.Module):\n",
"    \"\"\"Holds the image being optimized as a single learnable parameter.\"\"\"\n",
"\n",
"    def __init__(self, img_shape, **kwargs):\n",
"        super().__init__(**kwargs)\n",
"        # Random placeholder; get_inits overwrites it with its input image.\n",
"        self.weight = nn.Parameter(torch.rand(*img_shape))\n",
"\n",
"    def forward(self):\n",
"        return self.weight\n",
"\n",
"def get_inits(X, device, lr, styles_Y):\n",
"    \"\"\"Set up the synthesized image, style targets, and optimizer.\n",
"\n",
"    The synthesized image starts as a copy of X. Returns (image,\n",
"    styles_Y_gram, trainer), where styles_Y_gram is the Gram matrix of\n",
"    each style feature map and trainer is an Adam optimizer over the image.\n",
"    \"\"\"\n",
"    gen_img = SynthesizedImage(X.shape).to(device)\n",
"    # Initialize in place without autograd tracking; preferred over `.data`.\n",
"    with torch.no_grad():\n",
"        gen_img.weight.copy_(X)\n",
"    trainer = torch.optim.Adam(gen_img.parameters(), lr=lr)\n",
"    styles_Y_gram = [gram(Y) for Y in styles_Y]\n",
"    return gen_img(), styles_Y_gram, trainer"
]
},
{
"cell_type": "markdown",
"id": "28e26b15",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Training"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "b89e276f",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:53:05.039045Z",
"iopub.status.busy": "2023-08-18T19:53:05.038294Z",
"iopub.status.idle": "2023-08-18T19:53:05.045661Z",
"shell.execute_reply": "2023-08-18T19:53:05.044851Z"
},
"origin_pos": 41,
"tab": [
"pytorch"
]
},
"outputs": [],
"source": [
"def train(X, contents_Y, styles_Y, device, lr, num_epochs, lr_decay_epoch,\n",
"          lr_decay_rate=0.8):\n",
"    \"\"\"Optimize a synthesized image, starting from a copy of X.\n",
"\n",
"    The learning rate is multiplied by `lr_decay_rate` every\n",
"    `lr_decay_epoch` epochs. Every 10 epochs the current image and the\n",
"    content, style, and TV losses are plotted. Returns the optimized image.\n",
"    \"\"\"\n",
"    X, styles_Y_gram, trainer = get_inits(X, device, lr, styles_Y)\n",
"    scheduler = torch.optim.lr_scheduler.StepLR(trainer, lr_decay_epoch,\n",
"                                                lr_decay_rate)\n",
"    animator = d2l.Animator(xlabel='epoch', ylabel='loss',\n",
"                            xlim=[10, num_epochs],\n",
"                            legend=['content', 'style', 'TV'],\n",
"                            ncols=2, figsize=(7, 2.5))\n",
"    for epoch in range(num_epochs):\n",
"        trainer.zero_grad()\n",
"        contents_Y_hat, styles_Y_hat = extract_features(\n",
"            X, content_layers, style_layers)\n",
"        contents_l, styles_l, tv_l, l = compute_loss(\n",
"            X, contents_Y_hat, styles_Y_hat, contents_Y, styles_Y_gram)\n",
"        l.backward()\n",
"        trainer.step()\n",
"        scheduler.step()\n",
"        if (epoch + 1) % 10 == 0:\n",
"            # Display/logging only: keep these ops out of the autograd graph.\n",
"            with torch.no_grad():\n",
"                animator.axes[1].imshow(postprocess(X))\n",
"                animator.add(epoch + 1, [float(sum(contents_l)),\n",
"                                         float(sum(styles_l)), float(tv_l)])\n",
"    return X"
]
},
{
"cell_type": "markdown",
"id": "45b43cc9",
"metadata": {
"slideshow": {
"slide_type": "slide"
}
},
"source": [
"Now we train the model."
]
},
{
"cell_type": "code",
"execution_count": 17,
"id": "55724054",
"metadata": {
"execution": {
"iopub.execute_input": "2023-08-18T19:53:05.049186Z",
"iopub.status.busy": "2023-08-18T19:53:05.048632Z",
"iopub.status.idle": "2023-08-18T19:54:01.438578Z",
"shell.execute_reply": "2023-08-18T19:54:01.437688Z"
},
"origin_pos": 44,
"tab": [
"pytorch"
]
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"device = d2l.try_gpu()\n",
"image_shape = (300, 450)  # (height, width) passed to preprocess's Resize\n",
"net = net.to(device)\n",
"content_X, contents_Y = get_contents(image_shape, device)\n",
"_, styles_Y = get_styles(image_shape, device)\n",
"output = train(content_X, contents_Y, styles_Y, device,\n",
"               lr=0.3, num_epochs=500, lr_decay_epoch=50)"
]
}
],
"metadata": {
"celltoolbar": "Slideshow",
"language_info": {
"name": "python"
},
"required_libs": [],
"rise": {
"autolaunch": true,
"enable_chalkboard": true,
"overlay": "",
"scroll": true
}
},
"nbformat": 4,
"nbformat_minor": 5
}