# DCGAN — imports, hyper-parameters and the CIFAR10 data pipeline.
import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import CIFAR10
import numpy as np
from torch import optim
import torchvision.utils as vutil


class Config:
    """Hyper-parameters for the DCGAN experiment."""
    lr = 0.0002
    nz = 100          # dimensionality of the input noise vector
    image_size = 64
    image_size2 = 64
    nc = 3            # number of image channels (RGB)
    ngf = 64          # base feature-map count of the generator
    ndf = 64          # base feature-map count of the discriminator
    gpuids = None
    beta1 = 0.5
    batch_size = 32
    max_epoch = 1     # set to 1 while debugging
    workers = 2

opt = Config()

# Data loading / preprocessing: scale to 64x64 and map pixels to [-1, 1]
# (matches the generator's tanh output range).
preprocess = transforms.Compose([
    transforms.Scale(opt.image_size),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])
dataset = CIFAR10(root='cifar10/', transform=preprocess)
# DataLoader takes care of lazy loading, prefetching, worker threads and shuffling.
dataloader = t.utils.data.DataLoader(dataset, opt.batch_size, True,
                                     num_workers=opt.workers)
# DCGAN model definitions.

class ModelG(nn.Module):
    """DCGAN generator.

    Maps a noise tensor of shape (batch, nz, 1, 1) to an image of shape
    (batch, nc, 64, 64) with values in [-1, 1] (tanh output).
    """

    def __init__(self, ngpu):
        super(ModelG, self).__init__()
        self.ngpu = ngpu  # GPU count for data_parallel (None/0 -> default devices)
        self.model = nn.Sequential()
        # 1x1 -> 4x4
        self.model.add_module('deconv1', nn.ConvTranspose2d(opt.nz, opt.ngf * 8, 4, 1, 0, bias=False))
        self.model.add_module('bnorm1', nn.BatchNorm2d(opt.ngf * 8))
        self.model.add_module('relu1', nn.ReLU(True))
        # 4x4 -> 8x8
        self.model.add_module('deconv2', nn.ConvTranspose2d(opt.ngf * 8, opt.ngf * 4, 4, 2, 1, bias=False))
        self.model.add_module('bnorm2', nn.BatchNorm2d(opt.ngf * 4))
        self.model.add_module('relu2', nn.ReLU(True))
        # 8x8 -> 16x16
        self.model.add_module('deconv3', nn.ConvTranspose2d(opt.ngf * 4, opt.ngf * 2, 4, 2, 1, bias=False))
        self.model.add_module('bnorm3', nn.BatchNorm2d(opt.ngf * 2))
        self.model.add_module('relu3', nn.ReLU(True))
        # 16x16 -> 32x32
        self.model.add_module('deconv4', nn.ConvTranspose2d(opt.ngf * 2, opt.ngf, 4, 2, 1, bias=False))
        self.model.add_module('bnorm4', nn.BatchNorm2d(opt.ngf))
        self.model.add_module('relu4', nn.ReLU(True))
        # 32x32 -> 64x64
        self.model.add_module('deconv5', nn.ConvTranspose2d(opt.ngf, opt.nc, 4, 2, 1, bias=False))
        self.model.add_module('tanh', nn.Tanh())

    def forward(self, input):
        gpuids = None
        if self.ngpu:
            # BUG FIX: was range(gpuids), i.e. range(None) -> TypeError.
            gpuids = range(self.ngpu)
        return nn.parallel.data_parallel(self.model, input, device_ids=gpuids)


def weight_init(m):
    """DCGAN-paper weight initialisation (could be upgraded to Xavier init).

    BUG FIX: the original matched the lowercase substrings 'conv'/'norm',
    but module class names are 'Conv2d', 'ConvTranspose2d' and 'BatchNorm2d'
    and str.find is case-sensitive — so no layer was ever initialised.
    """
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
        m.weight.data.normal_(0, 0.02)
    if class_name.find('Norm') != -1:
        m.weight.data.normal_(1.0, 0.02)


class ModelD(nn.Module):
    """DCGAN discriminator.

    Maps a (batch, nc, 64, 64) image to a (batch, 1) probability of being real.
    """

    def __init__(self, ngpu):
        super(ModelD, self).__init__()
        self.ngpu = ngpu
        self.model = nn.Sequential()
        # 64x64 -> 32x32 (no batch-norm on the first layer, per the DCGAN paper)
        self.model.add_module('conv1', nn.Conv2d(opt.nc, opt.ndf, 4, 2, 1, bias=False))
        self.model.add_module('relu1', nn.LeakyReLU(0.2, inplace=True))
        # 32x32 -> 16x16
        self.model.add_module('conv2', nn.Conv2d(opt.ndf, opt.ndf * 2, 4, 2, 1, bias=False))
        self.model.add_module('bnorm2', nn.BatchNorm2d(opt.ndf * 2))
        self.model.add_module('relu2', nn.LeakyReLU(0.2, inplace=True))
        # 16x16 -> 8x8
        self.model.add_module('conv3', nn.Conv2d(opt.ndf * 2, opt.ndf * 4, 4, 2, 1, bias=False))
        self.model.add_module('bnorm3', nn.BatchNorm2d(opt.ndf * 4))
        self.model.add_module('relu3', nn.LeakyReLU(0.2, inplace=True))
        # 8x8 -> 4x4
        self.model.add_module('conv4', nn.Conv2d(opt.ndf * 4, opt.ndf * 8, 4, 2, 1, bias=False))
        self.model.add_module('bnorm4', nn.BatchNorm2d(opt.ndf * 8))
        self.model.add_module('relu4', nn.LeakyReLU(0.2, inplace=True))
        # 4x4 -> 1x1 score, squashed to a probability
        self.model.add_module('conv5', nn.Conv2d(opt.ndf * 8, 1, 4, 1, 0, bias=False))
        self.model.add_module('sigmoid', nn.Sigmoid())

    def forward(self, input):
        gpuids = None
        if self.ngpu:
            # BUG FIX: was range(gpuids) -> TypeError (see ModelG.forward).
            gpuids = range(self.ngpu)
        return nn.parallel.data_parallel(self.model, input, device_ids=gpuids).view(-1, 1)


netg = ModelG(opt.gpuids)
netg.apply(weight_init)
netd = ModelD(opt.gpuids)
netd.apply(weight_init)
# Optimizers for the DCGAN game.
optimizerD = optim.Adam(netd.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))
optimizerG = optim.Adam(netg.parameters(), lr=opt.lr, betas=(opt.beta1, 0.999))

# Reusable input/output buffers (old-style Variable API).
input = Variable(t.FloatTensor(opt.batch_size, opt.nc, opt.image_size, opt.image_size2))
label = Variable(t.FloatTensor(opt.batch_size))
noise = Variable(t.FloatTensor(opt.batch_size, opt.nz, 1, 1))
# Fixed noise so that saved samples are comparable across iterations.
fixed_noise = Variable(t.FloatTensor(opt.batch_size, opt.nz, 1, 1).normal_(0, 1))
real_label = 1
fake_label = 0

# --- training ---
criterion = nn.BCELoss()

# BUG FIX: error_G / D_G_z2 were only assigned inside the random G-update
# branch below, so the print statement raised NameError on any iteration
# before the branch first fired.  Start both at 0.
error_G_value = 0.0
D_G_z2 = 0.0

# BUG FIX: the epoch count was hard-coded as xrange(6) — xrange does not
# exist in Python 3 (the notebook kernel is 3.5) and the 6 ignored
# opt.max_epoch, which the WGAN section uses consistently.
for epoch in range(opt.max_epoch):
    for ii, data in enumerate(dataloader, 0):

        # ---- train D ----
        netd.zero_grad()
        # real images
        real, _ = data
        input.data.resize_(real.size()).copy_(real)
        label.data.resize_(input.size()[0]).fill_(real_label)
        output = netd(input)
        error_real = criterion(output, label)
        error_real.backward()
        D_x = output.data.mean()
        # fake images — detach so no gradient flows into G while training D
        noise.data.resize_(input.size()[0], opt.nz, 1, 1).normal_(0, 1)
        fake_pic = netg(noise).detach()
        output2 = netd(fake_pic)
        label.data.fill_(fake_label)
        error_fake = criterion(output2, label)
        error_fake.backward()
        D_x2 = output2.data.mean()
        error_D = error_real + error_fake
        optimizerD.step()

        # ---- train G: roughly one G update per two D updates ----
        if t.rand(1)[0] > 0.5:
            netg.zero_grad()
            label.data.fill_(real_label)  # G wants D to call its output real
            noise.data.normal_(0, 1)
            fake_pic = netg(noise)
            output = netd(fake_pic)
            error_G = criterion(output, label)
            error_G.backward()
            optimizerG.step()
            D_G_z2 = output.data.mean()
            error_G_value = error_G.data[0]

        print('{ii}/{epoch} lossD:{error_D},lossG:{error_G},{D_x2},{D_G_z2},{D_x}'.format(
            ii=ii, epoch=epoch,
            error_D=error_D.data[0], error_G=error_G_value,
            D_x2=D_x2, D_G_z2=D_G_z2, D_x=D_x))

        if ii % 100 == 0 and ii > 0:
            # Periodically dump fixed-noise samples next to a real batch.
            fake_u = netg(fixed_noise)
            vutil.save_image(fake_u.data, 'fake%s.png' % ii)
            vutil.save_image(real, 'real%s.png' % ii)

# Persist the trained weights.
t.save(netd.state_dict(), '1epoch_netd.pth')
t.save(netg.state_dict(), '1epoch_netg.pth')
# WGAN — imports, notes, hyper-parameters and the CIFAR10 data pipeline.
import torch as t
from torch import nn
from torch.autograd import Variable
from torch.optim import Adam
from torchvision import transforms
from torchvision.utils import save_image
from torchvision.datasets import CIFAR10
import numpy as np
from torch import optim
import torchvision.utils as vutil
# from tensorboard_logger import Logger

# WGAN changes relative to DCGAN (see https://zhuanlan.zhihu.com/p/25071913):
# 1. drop the discriminator's final sigmoid  (regression score, not a probability)
# 2. no log in the G/D losses               (Wasserstein distance)
# 3. after each critic update, clamp its weights to a fixed range [-c, c]
#    (Wasserstein distance -> Lipschitz constraint -> numerical stability)
# 4. avoid momentum-based optimisers (momentum, Adam); RMSProp or SGD instead
#
# Why the vanilla GAN struggles:
# - mode collapse -> asymmetry of the KL divergence
# - instability   -> KL and JS divergences pull the optimisation in different directions


class Config:
    """Hyper-parameters for the WGAN experiment."""
    lr = 0.0002
    nz = 100           # dimensionality of the input noise vector
    image_size = 64
    image_size2 = 64
    nc = 3             # number of image channels (RGB)
    ngf = 64           # base feature-map count of the generator
    ndf = 64           # base feature-map count of the critic
    gpuids = None
    beta1 = 0.5
    batch_size = 32
    max_epoch = 12     # set to 1 while debugging
    workers = 2
    clamp_num = 0.01   # WGAN weight-clipping bound

opt = Config()

# Data loading: scale to 64x64 and normalise pixels to [-1, 1].
dataset = CIFAR10(
    root='cifar10//',
    transform=transforms.Compose([
        transforms.Scale(opt.image_size),
        transforms.ToTensor(),
        transforms.Normalize([0.5] * 3, [0.5] * 3),
    ]))
# DataLoader takes care of lazy loading, prefetching, worker threads and shuffling.
dataloader = t.utils.data.DataLoader(dataset, opt.batch_size, True,
                                     num_workers=opt.workers)
# WGAN network structure — same topology as the DCGAN version, except the
# critic (D) has no final sigmoid: it outputs an unbounded score.

class ModelG(nn.Module):
    """WGAN generator: (batch, nz, 1, 1) noise -> (batch, nc, 64, 64) image in [-1, 1]."""

    def __init__(self, ngpu):
        super(ModelG, self).__init__()
        self.ngpu = ngpu  # GPU count for data_parallel (None/0 -> default devices)
        self.model = nn.Sequential()
        # 1x1 -> 4x4
        self.model.add_module('deconv1', nn.ConvTranspose2d(opt.nz, opt.ngf * 8, 4, 1, 0, bias=False))
        self.model.add_module('bnorm1', nn.BatchNorm2d(opt.ngf * 8))
        self.model.add_module('relu1', nn.ReLU(True))
        # 4x4 -> 8x8
        self.model.add_module('deconv2', nn.ConvTranspose2d(opt.ngf * 8, opt.ngf * 4, 4, 2, 1, bias=False))
        self.model.add_module('bnorm2', nn.BatchNorm2d(opt.ngf * 4))
        self.model.add_module('relu2', nn.ReLU(True))
        # 8x8 -> 16x16
        self.model.add_module('deconv3', nn.ConvTranspose2d(opt.ngf * 4, opt.ngf * 2, 4, 2, 1, bias=False))
        self.model.add_module('bnorm3', nn.BatchNorm2d(opt.ngf * 2))
        self.model.add_module('relu3', nn.ReLU(True))
        # 16x16 -> 32x32
        self.model.add_module('deconv4', nn.ConvTranspose2d(opt.ngf * 2, opt.ngf, 4, 2, 1, bias=False))
        self.model.add_module('bnorm4', nn.BatchNorm2d(opt.ngf))
        self.model.add_module('relu4', nn.ReLU(True))
        # 32x32 -> 64x64
        self.model.add_module('deconv5', nn.ConvTranspose2d(opt.ngf, opt.nc, 4, 2, 1, bias=False))
        self.model.add_module('tanh', nn.Tanh())

    def forward(self, input):
        gpuids = None
        if self.ngpu:
            # BUG FIX: was range(gpuids), i.e. range(None) -> TypeError.
            gpuids = range(self.ngpu)
        return nn.parallel.data_parallel(self.model, input, device_ids=gpuids)


def weight_init(m):
    """DCGAN-paper weight initialisation (could be swapped for Xavier init).

    BUG FIX: the original matched the lowercase substrings 'conv'/'norm';
    class names are 'Conv2d'/'ConvTranspose2d'/'BatchNorm2d' and str.find is
    case-sensitive, so the initialiser was a silent no-op.
    """
    class_name = m.__class__.__name__
    if class_name.find('Conv') != -1:
        m.weight.data.normal_(0, 0.02)
    if class_name.find('Norm') != -1:
        m.weight.data.normal_(1.0, 0.02)


class ModelD(nn.Module):
    """WGAN critic: (batch, nc, 64, 64) image -> mean critic score, shape (1,)."""

    def __init__(self, ngpu):
        super(ModelD, self).__init__()
        self.ngpu = ngpu
        self.model = nn.Sequential()
        # 64x64 -> 32x32 (no batch-norm on the first layer)
        self.model.add_module('conv1', nn.Conv2d(opt.nc, opt.ndf, 4, 2, 1, bias=False))
        self.model.add_module('relu1', nn.LeakyReLU(0.2, inplace=True))
        # 32x32 -> 16x16
        self.model.add_module('conv2', nn.Conv2d(opt.ndf, opt.ndf * 2, 4, 2, 1, bias=False))
        self.model.add_module('bnorm2', nn.BatchNorm2d(opt.ndf * 2))
        self.model.add_module('relu2', nn.LeakyReLU(0.2, inplace=True))
        # 16x16 -> 8x8
        self.model.add_module('conv3', nn.Conv2d(opt.ndf * 2, opt.ndf * 4, 4, 2, 1, bias=False))
        self.model.add_module('bnorm3', nn.BatchNorm2d(opt.ndf * 4))
        self.model.add_module('relu3', nn.LeakyReLU(0.2, inplace=True))
        # 8x8 -> 4x4
        self.model.add_module('conv4', nn.Conv2d(opt.ndf * 4, opt.ndf * 8, 4, 2, 1, bias=False))
        self.model.add_module('bnorm4', nn.BatchNorm2d(opt.ndf * 8))
        self.model.add_module('relu4', nn.LeakyReLU(0.2, inplace=True))
        # 4x4 -> 1x1 raw score
        self.model.add_module('conv5', nn.Conv2d(opt.ndf * 8, 1, 4, 1, 0, bias=False))
        # WGAN modification: no sigmoid — the critic outputs a score, not a probability.
        # self.model.add_module('sigmoid', nn.Sigmoid())

    def forward(self, input):
        gpuids = None
        if self.ngpu:
            # BUG FIX: was range(gpuids) -> TypeError (see ModelG.forward).
            gpuids = range(self.ngpu)
        # Mean score over the batch — a scalar "realness" estimate, not a loss.
        return nn.parallel.data_parallel(self.model, input, device_ids=gpuids).view(-1, 1).mean(0).view(1)


netg = ModelG(opt.gpuids)
netg.apply(weight_init)
netd = ModelD(opt.gpuids)
netd.apply(weight_init)

# WGAN modification: avoid momentum-based optimisers such as Adam; use RMSprop.
optimizerD = optim.RMSprop(netd.parameters(), lr=opt.lr)
optimizerG = optim.RMSprop(netg.parameters(), lr=opt.lr)

# Reusable input buffers for D and G (old-style Variable API).
input = Variable(t.FloatTensor(opt.batch_size, opt.nc, opt.image_size, opt.image_size2))
label = Variable(t.FloatTensor(opt.batch_size))
noise = Variable(t.FloatTensor(opt.batch_size, opt.nz, 1, 1))
import os

# Fixed noise so that saved samples are comparable across iterations.
fixed_noise = Variable(t.FloatTensor(opt.batch_size, opt.nz, 1, 1).normal_(0, 1))
real_label = 1
fake_label = 0

# criterion = nn.BCELoss()  # WGAN does not use a log (cross-entropy) loss.
# Gradient "directions" fed to backward(): +1 for the real/G terms, -1 for fake.
one = t.FloatTensor([1])
mone = -1 * one

# BUG FIX: samples are written into 'wgan/…' below, which crashes if the
# directory does not exist — create it up front.
if not os.path.exists('wgan'):
    os.makedirs('wgan')

# --- training ---
# BUG FIX: xrange does not exist in Python 3 (the notebook kernel is 3.5).
for epoch in range(opt.max_epoch):
    for ii, data in enumerate(dataloader, 0):

        # ---- train the critic (D) ----
        netd.zero_grad()  # necessary: gradients accumulate otherwise
        real, _ = data
        input.data.resize_(real.size()).copy_(real)
        # NOTE: labels are vestigial here — WGAN has no classification loss.
        label.data.resize_(input.size()[0]).fill_(real_label)
        output = netd(input)
        output.backward(one)  # WGAN: back-prop the raw critic score directly
        D_x = output.data.mean()

        noise.data.resize_(input.size()[0], opt.nz, 1, 1).normal_(0, 1)
        fake_pic = netg(noise).detach()  # no gradient into G while training D
        output2 = netd(fake_pic)
        label.data.fill_(fake_label)
        output2.backward(mone)  # opposite direction for the fake term
        D_x2 = output2.data.mean()
        optimizerD.step()
        # Weight clipping — critic only — keeps D approximately Lipschitz.
        for parm in netd.parameters():
            parm.data.clamp_(-opt.clamp_num, opt.clamp_num)

        # ---- train G: roughly one G update per five critic updates ----
        if t.rand(1)[0] > 0.8:
            netg.zero_grad()
            label.data.fill_(real_label)
            noise.data.normal_(0, 1)
            fake_pic = netg(noise)
            output = netd(fake_pic)
            output.backward(one)
            optimizerG.step()
            # G is NOT clipped — the Lipschitz constraint applies to the critic only.
            D_G_z2 = output.data.mean()

        if ii % 100 == 0 and ii > 0:
            # Periodically dump fixed-noise samples next to a real batch.
            fake_u = netg(fixed_noise)
            vutil.save_image(fake_u.data, 'wgan/fake%s_%s.png' % (epoch, ii))
            vutil.save_image(real, 'wgan/real%s_%s.png' % (epoch, ii))

# Persist the trained weights.
t.save(netd.state_dict(), 'epoch_netd.pth')
t.save(netg.state_dict(), 'epoch_netg.pth')
"codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.5.2" }, "toc": { "colors": { "hover_highlight": "#DAA520", "navigate_num": "#000000", "navigate_text": "#333333", "running_highlight": "#FF0000", "selected_highlight": "#FFD700", "sidebar_border": "#EEEEEE", "wrapper_background": "#FFFFFF" }, "moveMenuLeft": true, "nav_menu": { "height": "30px", "width": "252px" }, "navigate_menu": true, "number_sections": true, "sideBar": true, "threshold": 4, "toc_cell": false, "toc_section_display": "block", "toc_window_display": false, "widenNotebook": false } }, "nbformat": 4, "nbformat_minor": 2 }