In [None]:
import os

import torch
from PIL import Image
from torch import nn
from torch import optim
from torchvision import models
from torchvision import transforms
from torchvision.utils import save_image
from tqdm import tqdm

In [None]:
def read_images(img_path, size=(512, 512), device='cuda'):
    """Load an image file as a float RGB tensor.

    Args:
        img_path: Path of the image file to open.
        size: ``(height, width)`` the image is resized to before conversion.
            Defaults to ``(512, 512)``.
        device: Device the resulting tensor is moved to. Defaults to ``'cuda'``.

    Returns:
        A ``torch.float`` tensor on ``device`` holding the resized RGB image
        in channel-first layout (no batch dimension is added).
    """
    to_tensor = transforms.Compose([
        transforms.Resize(size),
        transforms.ToTensor(),
    ])
    # The context manager closes the underlying file as soon as the pixels
    # have been decoded; convert() returns an independent in-memory copy.
    with Image.open(img_path) as raw_img:
        rgb_img = raw_img.convert('RGB')
    return to_tensor(rgb_img).to(device, torch.float)

In [None]:
class StyleTransferModel(nn.Module):
    """Feature extractor built from the first 29 layers of a pretrained VGG-19.

    ``forward`` runs the input through the truncated network and collects the
    activations produced at the layer indices listed in ``random_layers``.

    NOTE(review): torchvision builds VGG with in-place ReLU layers, so a tensor
    captured right after a conv layer may be overwritten by the following ReLU.
    Confirm whether pre- or post-activation features are intended.
    """

    def __init__(self):
        super().__init__()
        # Newer torchvision deprecates `pretrained=True` in favour of the
        # `weights=` enum; fall back to the legacy flag on older releases.
        weights_enum = getattr(models, "VGG19_Weights", None)
        if weights_enum is not None:
            vgg = models.vgg19(weights=weights_enum.IMAGENET1K_V1)
        else:
            vgg = models.vgg19(pretrained=True)

        self.model = vgg.features[:29]
        # Indices (within self.model) of the layers whose outputs are returned.
        self.random_layers = [0, 5, 10, 19, 28]

        # The network is only used as a fixed feature extractor: freezing the
        # weights stops backward() from computing and accumulating gradients
        # for them. Gradients still flow to the input tensor.
        for param in self.model.parameters():
            param.requires_grad_(False)

    def forward(self, x):
        """Return the list of activations at the selected layers, in order.

        Args:
            x: Input tensor fed to the first layer of ``self.model``.

        Returns:
            A list with one tensor per index in ``random_layers`` that is
            reached while iterating over ``self.model``.
        """
        features = []
        for layer_idx, layer in enumerate(self.model):
            x = layer(x)
            if layer_idx in self.random_layers:
                features.append(x)
        return features

In [None]:
def content_loss(x, content_img):
    """Mean squared difference between two tensors.

    Args:
        x: Tensor for the image being optimised.
        content_img: Tensor for the content reference, broadcast-compatible
            with ``x``.

    Returns:
        A scalar tensor: the mean of the element-wise squared difference.
    """
    difference = x - content_img
    return torch.mean(difference ** 2)

In [None]:
def style_loss(x, style_image):
    """Mean squared difference between the Gram matrices of two feature maps.

    Each input is flattened over its last two (spatial) dimensions and
    multiplied by its own transpose to form an unnormalised Gram matrix of
    size ``C x C``. Both ``(C, H, W)`` inputs and inputs with leading batch
    dimensions such as ``(N, C, H, W)`` are accepted; with batch dimensions
    one Gram matrix is computed per sample.

    Args:
        x: Feature map of the image being optimised.
        style_image: Feature map of the style reference, with the same number
            of elements as ``x``.

    Returns:
        A scalar tensor: the mean of the squared element-wise difference of
        the two Gram matrices (averaged over any batch dimensions too).

    Raises:
        ValueError: If ``x`` has fewer than three dimensions.
    """
    # Star-unpacking keeps any leading batch dimensions and raises ValueError
    # for inputs with fewer than three dimensions.
    *batch_dims, c, h, w = x.shape
    x_flat = x.reshape(*batch_dims, c, h * w)
    style_flat = style_image.reshape(*batch_dims, c, h * w)

    # matmul reduces to a plain matrix product for 2-D inputs and to a batched
    # product otherwise.
    gram_x = torch.matmul(x_flat, x_flat.transpose(-2, -1))
    gram_style = torch.matmul(style_flat, style_flat.transpose(-2, -1))

    return torch.mean((gram_x - gram_style) ** 2)

In [None]:
# Weights applied to the summed content loss (alpha) and style loss (beta).
alpha, beta = 8, 70


def calculate_loss(x_f, content_f, style_f):
    """Combine per-layer content and style losses into one weighted total.

    Args:
        x_f: Iterable of feature maps for the image being optimised.
        content_f: Iterable of feature maps for the content image.
        style_f: Iterable of feature maps for the style image.

    Returns:
        ``alpha * (sum of content losses) + beta * (sum of style losses)``,
        where the sums run over the layers paired up by ``zip``.
    """
    per_layer = [
        (content_loss(generated, content_ref), style_loss(generated, style_ref))
        for generated, content_ref, style_ref in zip(x_f, content_f, style_f)
    ]
    content_l = sum(pair[0] for pair in per_layer)
    style_l = sum(pair[1] for pair in per_layer)
    return alpha * content_l + beta * style_l

In [None]:
# Locations of the style reference and of the image to be stylised.
STYLE_IMAGE_PATH = "/content/Van_Gogh_-_Starry_Night_-_Google_Art_Project.jpg"
CONTENT_IMAGE_PATH = "/content/Wolverine.png"

style = read_images(STYLE_IMAGE_PATH)
content = read_images(CONTENT_IMAGE_PATH)

# The image being optimised starts as a copy of the content image and is the
# only tensor that tracks gradients.
x = content.clone()
x.requires_grad_(True)

In [None]:
# Feature extractor, switched to eval mode and moved to the GPU.
model = StyleTransferModel()
model = model.eval().to('cuda')

# Adam is given only the image tensor x, so x is the only thing it updates.
optimizer = optim.Adam([x], lr=0.004)



In [None]:
NUM_STEPS = 7000      # number of optimisation steps
SAVE_EVERY = 100      # write a snapshot every this many steps
OUTPUT_DIR = './output'

# save_image does not create missing directories, so make sure the target
# exists before the first snapshot is written.
os.makedirs(OUTPUT_DIR, exist_ok=True)

# The style and content images never change, so their features are computed
# once, without autograd tracking, instead of on every iteration.
with torch.no_grad():
    style_features = model(style)
    content_features = model(content)

progress_bar = tqdm(range(NUM_STEPS), desc="optimizing")
for i in progress_bar:
    x_features = model(x)
    total_loss = calculate_loss(x_features, content_features, style_features)

    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    progress_bar.set_postfix({"total_loss": f"{total_loss.item()}"})

    # Periodic snapshots, plus the final step so the finished image is not lost.
    if i % SAVE_EVERY == 0 or i == NUM_STEPS - 1:
        save_image(x.detach(), f'{OUTPUT_DIR}/wolvi_{i}.png')




optimizing: 100%|██████████| 7000/7000 [42:01<00:00, 2.78it/s, total_loss=2417662.75]
