{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# DCGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import torch as t\n",
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "from torch.optim import Adam\n",
    "from torchvision import transforms\n",
    "from torchvision.utils import save_image\n",
    "from torchvision.datasets import CIFAR10\n",
    "import numpy as np\n",
    "from torch  import optim\n",
    "import torchvision.utils as vutil"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Config:\n",
    "    lr=0.0002\n",
    "    nz=100# 噪声维度\n",
    "    image_size=64\n",
    "    image_size2=64\n",
    "    nc=3# 图片三通道\n",
    "    ngf=64 #生成图片\n",
    "    ndf=64 #判别图片\n",
    "    gpuids=None\n",
    "    beta1=0.5\n",
    "    batch_size=32\n",
    "    max_epoch=1# =1 when debug\n",
    "    workers=2\n",
    "    \n",
    "opt=Config()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 数据加载和预处理 \n",
    "dataset=CIFAR10(root='cifar10/',\\\n",
    "                transform=transforms.Compose(\\\n",
    "                                             [transforms.Scale(opt.image_size) ,\n",
    "                                              transforms.ToTensor(),\n",
    "                                              transforms.Normalize([0.5]*3,[0.5]*3)\n",
    "                                              \n",
    "                                             ]))\n",
    "# 什么惰性加载,预加载,多线程,乱序  全都解决 \n",
    "dataloader=t.utils.data.DataLoader(dataset,opt.batch_size,True,num_workers=opt.workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ModelD (\n",
       "  (model): Sequential (\n",
       "    (conv1): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (relu1): LeakyReLU (0.2, inplace)\n",
       "    (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n",
       "    (relu2): LeakyReLU (0.2, inplace)\n",
       "    (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bnorm3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n",
       "    (relu3): LeakyReLU (0.2, inplace)\n",
       "    (conv4): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n",
       "    (relu4): LeakyReLU (0.2, inplace)\n",
       "    (conv5): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)\n",
       "    (sigmoid): Sigmoid ()\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#模型定义\n",
    "class ModelG(nn.Module):\n",
    "    def __init__(self,ngpu):\n",
    "        super(ModelG,self).__init__()\n",
    "        self.ngpu=ngpu\n",
    "        self.model=nn.Sequential()\n",
    "        self.model.add_module('deconv1',nn.ConvTranspose2d(opt.nz,opt.ngf*8,4,1,0,bias=False))\n",
    "        self.model.add_module('bnorm1',nn.BatchNorm2d(opt.ngf*8))\n",
    "        self.model.add_module('relu1',nn.ReLU(True))\n",
    "        self.model.add_module('deconv2',nn.ConvTranspose2d(opt.ngf*8,opt.ngf*4,4,2,1,bias=False))\n",
    "        self.model.add_module('bnorm2',nn.BatchNorm2d(opt.ngf*4))\n",
    "        self.model.add_module('relu2',nn.ReLU(True))\n",
    "        self.model.add_module('deconv3',nn.ConvTranspose2d(opt.ngf*4,opt.ngf*2,4,2,1,bias=False))\n",
    "        self.model.add_module('bnorm3',nn.BatchNorm2d(opt.ngf*2))\n",
    "        self.model.add_module('relu3',nn.ReLU(True))\n",
    "        \n",
    "        self.model.add_module('deconv4',nn.ConvTranspose2d(opt.ngf*2,opt.ngf,4,2,1,bias=False))\n",
    "        self.model.add_module('bnorm4',nn.BatchNorm2d(opt.ngf))\n",
    "        self.model.add_module('relu4',nn.ReLU(True))\n",
    "        \n",
    "        self.model.add_module('deconv5',nn.ConvTranspose2d(opt.ngf,opt.nc,4,2,1,bias=False))\n",
    "        self.model.add_module('tanh',nn.Tanh())\n",
    "    def forward(self,input):\n",
    "        gpuids=None\n",
    "        if self.ngpu:\n",
     "            gpuids=range(gpuids)\n",
    "        return nn.parallel.data_parallel(self.model,input, device_ids=gpuids)\n",
    "\n",
    "def weight_init(m):\n",
    "    #模型参数初始化. 可以优化成为xavier 初始化\n",
    "    class_name=m.__class__.__name__\n",
     "    if class_name.find('Conv')!=-1:\n",
     "        m.weight.data.normal_(0,0.02)\n",
     "    if class_name.find('Norm')!=-1:\n",
     "        m.weight.data.normal_(1.0,0.02)\n",
    "    \n",
    "class ModelD(nn.Module):\n",
    "    def __init__(self,ngpu):\n",
    "        super(ModelD,self).__init__()\n",
    "        self.ngpu=ngpu\n",
    "        self.model=nn.Sequential()\n",
    "        self.model.add_module('conv1',nn.Conv2d(opt.nc,opt.ndf,4,2,1,bias=False))\n",
    "        self.model.add_module('relu1',nn.LeakyReLU(0.2,inplace=True))\n",
    "        \n",
    "        self.model.add_module('conv2',nn.Conv2d(opt.ndf,opt.ndf*2,4,2,1,bias=False))\n",
    "        self.model.add_module('bnorm2',nn.BatchNorm2d(opt.ndf*2))\n",
    "        self.model.add_module('relu2',nn.LeakyReLU(0.2,inplace=True))\n",
    "        \n",
    "        self.model.add_module('conv3',nn.Conv2d(opt.ndf*2,opt.ndf*4,4,2,1,bias=False))\n",
    "        self.model.add_module('bnorm3',nn.BatchNorm2d(opt.ndf*4))\n",
    "        self.model.add_module('relu3',nn.LeakyReLU(0.2,inplace=True))\n",
    "        \n",
    "        self.model.add_module('conv4',nn.Conv2d(opt.ndf*4,opt.ndf*8,4,2,1,bias=False))\n",
    "        self.model.add_module('bnorm4',nn.BatchNorm2d(opt.ndf*8))\n",
    "        self.model.add_module('relu4',nn.LeakyReLU(0.2,inplace=True))\n",
    "        \n",
    "        self.model.add_module('conv5',nn.Conv2d(opt.ndf*8,1,4,1,0,bias=False))\n",
    "        self.model.add_module('sigmoid',nn.Sigmoid())\n",
    "    def forward(self,input):\n",
    "        gpuids=None\n",
    "        if self.ngpu:\n",
     "            gpuids=range(self.ngpu)\n",
    "        return nn.parallel.data_parallel(self.model,input, device_ids=gpuids).view(-1,1)\n",
    "netg=ModelG(opt.gpuids)\n",
    "netg.apply(weight_init)\n",
    "netd=ModelD(opt.gpuids)\n",
    "netd.apply(weight_init)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 优化器\n",
    "optimizerD=optim.Adam(netd.parameters(),lr=opt.lr,betas=(opt.beta1,0.999))\n",
    "optimizerG=optim.Adam(netg.parameters(),lr=opt.lr,betas=(opt.beta1,0.999))\n",
    "\n",
    "# 模型的输入输出\n",
    "input=Variable(t.FloatTensor(opt.batch_size,opt.nc,opt.image_size,opt.image_size2))\n",
    "label=Variable(t.FloatTensor(opt.batch_size))\n",
    "noise=Variable(t.FloatTensor(opt.batch_size,opt.nz,1,1))\n",
    "fixed_noise=Variable(t.FloatTensor(opt.batch_size,opt.nz,1,1).normal_(0,1))\n",
    "real_label=1\n",
    "fake_label=0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "0/0     lossD:1.36829459667,lossG:0.894766509533,0.486646070145,0.411444487981,0.504189566709\n",
      "1/0     lossD:1.27741086483,lossG:0.946315646172,0.443847945891,0.392864650115,0.51032573171\n",
      "2/0     lossD:1.3073836565,lossG:0.946315646172,0.4495063629,0.392864650115,0.49758704938\n",
      "3/0     lossD:1.13484323025,lossG:0.978585541248,0.415155397728,0.380839592777,0.554169878364\n",
      "4/0     lossD:1.16920399666,lossG:0.955828249454,0.436097631231,0.391813380178,0.556802288629\n",
      "5/0     lossD:1.1938097477,lossG:0.9689848423,0.471667434089,0.387160736136,0.583109146915\n",
      "6/0     lossD:1.21595573425,lossG:0.9689848423,0.451340062544,0.387160736136,0.552521592937\n",
      "7/0     lossD:1.10636794567,lossG:1.00885415077,0.392565631308,0.368035835214,0.551878349856\n",
      "8/0     lossD:1.1512260437,lossG:1.00885415077,0.468347774819,0.368035835214,0.605279957876\n",
      "9/0     lossD:1.02875673771,lossG:1.00885415077,0.363269736525,0.368035835214,0.569313969463\n",
      "10/0     lossD:0.985520362854,lossG:1.17760324478,0.292472077999,0.314572186675,0.534999700263\n",
      "11/0     lossD:1.20287775993,lossG:1.0366050005,0.521362030879,0.360163018573,0.63543565385\n",
      "12/0     lossD:1.22502338886,lossG:1.03986740112,0.467908551916,0.357476376463,0.562215656973\n",
      "13/0     lossD:1.33360874653,lossG:1.02025258541,0.459647334181,0.369051158894,0.496054288\n"
     ]
    }
   ],
   "source": [
    "# 训练\n",
    "\n",
     "criterion=nn.BCELoss()\n",
     "# placeholders so the log line below is valid even before G's first update\n",
     "error_G=Variable(t.zeros(1))\n",
     "D_G_z2=0\n",
     "for epoch in range(opt.max_epoch):\n",
    "    for ii, data in enumerate(dataloader,0):\n",
    "        \n",
    "        #训练 D 网\n",
    "        netd.zero_grad()\n",
    "        #真实图片\n",
    "        real,_=data\n",
    "        input.data.resize_(real.size()).copy_(real)\n",
    "        label.data.resize_(input.size()[0]).fill_(real_label)\n",
    "        output=netd(input)\n",
    "        error_real=criterion(output,label)\n",
    "        error_real.backward()\n",
    "        D_x=output.data.mean()\n",
    "        # 假图片\n",
    "        noise.data.resize_(input.size()[0],opt.nz,1,1 ).normal_(0,1)\n",
    "        fake_pic=netg(noise).detach()\n",
    "        output2=netd(fake_pic)\n",
    "        label.data.fill_(fake_label)\n",
    "        error_fake=criterion(output2,label)\n",
    "        error_fake.backward()\n",
    "        D_x2=output2.data.mean()\n",
    "        error_D=error_real+error_fake\n",
    "        optimizerD.step()\n",
    "        \n",
    "        # 训练 G网  G网和D网训练次数1:2\n",
    "        if t.rand(1)[0]>0.5:   \n",
    "            netg.zero_grad()\n",
    "            label.data.fill_(real_label)\n",
    "            noise.data.normal_(0,1)\n",
    "            fake_pic=netg(noise)\n",
    "            output=netd(fake_pic)\n",
    "            error_G=criterion(output,label)\n",
    "            error_G.backward()\n",
    "            optimizerG.step()\n",
    "            D_G_z2=output.data.mean()\n",
    "        \n",
    "        print ('{ii}/{epoch}     lossD:{error_D},lossG:{error_G},{D_x2},{D_G_z2},{D_x}'.format(ii=ii,epoch=epoch,\\\n",
    "                error_D=error_D.data[0],error_G=error_G.data[0],\\\n",
    "                D_x2=D_x2,D_G_z2=D_G_z2,D_x=D_x))\n",
    "        if ii%100==0 and ii>0:\n",
    "            fake_u=netg(fixed_noise)\n",
    "            vutil.save_image(fake_u.data,'fake%s.png'%ii)\n",
    "            vutil.save_image(real,'real%s.png' %ii)\n",
    "          "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "t.save(netd.state_dict(),'1epoch_netd.pth')\n",
    "t.save(netg.state_dict(),'1epoch_netg.pth')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# WGAN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch as t\n",
    "from torch import nn\n",
    "from torch.autograd import Variable\n",
    "from torch.optim import Adam\n",
    "from torchvision import transforms\n",
    "from torchvision.utils import save_image\n",
    "from torchvision.datasets import CIFAR10\n",
    "import numpy as np\n",
    "from torch  import optim\n",
    "import torchvision.utils as vutil\n",
    "#from tensorboard_logger import Logger"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "'''\n",
    "https://zhuanlan.zhihu.com/p/25071913\n",
    "WGAN 相比于DCGAN 的修改:\n",
    "1. 判别器最后一层去掉sigmoid                                       # 回归问题,而不是二分类概率\n",
    "2. 生成器和判别器的loss不取log                                      # Wasserstein 距离\n",
    "3. 每次更新判别器的参数之后把它们的绝对值截断到不超过一个固定常数c          #W距离->L连续->数值稳定\n",
    "4. 不要用基于动量的优化算法(包括momentum和Adam),推荐RMSProp,SGD也行  #->玄学\n",
    "\n",
    "GAN 两大问题的解释:\n",
    "collapse mode ->KL 散度不对称\n",
    "数值不稳定 -> KL散度和JS散度优化方向不一样\n",
    "'''\n",
    "class Config:\n",
    "    lr=0.0002\n",
    "    nz=100# 噪声维度\n",
    "    image_size=64\n",
    "    image_size2=64\n",
    "    nc=3# 图片三通道\n",
    "    ngf=64 #生成图片\n",
    "    ndf=64 #判别图片\n",
    "    gpuids=None\n",
    "    beta1=0.5\n",
    "    batch_size=32\n",
    "    max_epoch=12# =1 when debug\n",
    "    workers=2\n",
    "    clamp_num=0.01# WGAN 截断大小\n",
    "    \n",
    "opt=Config()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 加载数据\n",
    "dataset=CIFAR10(root='cifar10//',\\\n",
    "                transform=transforms.Compose(\\\n",
    "                                             [transforms.Scale(opt.image_size) ,\n",
    "                                              transforms.ToTensor(),\n",
    "                                              transforms.Normalize([0.5]*3,[0.5]*3)\n",
    "                                             ]))\n",
    "# 什么惰性加载,预加载,多线程,乱序  全都解决 \n",
    "dataloader=t.utils.data.DataLoader(dataset,opt.batch_size,True,num_workers=opt.workers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ModelD (\n",
       "  (model): Sequential (\n",
       "    (conv1): Conv2d(3, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (relu1): LeakyReLU (0.2, inplace)\n",
       "    (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bnorm2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)\n",
       "    (relu2): LeakyReLU (0.2, inplace)\n",
       "    (conv3): Conv2d(128, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bnorm3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True)\n",
       "    (relu3): LeakyReLU (0.2, inplace)\n",
       "    (conv4): Conv2d(256, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "    (bnorm4): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True)\n",
       "    (relu4): LeakyReLU (0.2, inplace)\n",
       "    (conv5): Conv2d(512, 1, kernel_size=(4, 4), stride=(1, 1), bias=False)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 网络结构\n",
    "\n",
    "class ModelG(nn.Module):\n",
    "    def __init__(self,ngpu):\n",
    "        super(ModelG,self).__init__()\n",
    "        self.ngpu=ngpu\n",
    "        self.model=nn.Sequential()\n",
    "        self.model.add_module('deconv1',nn.ConvTranspose2d(opt.nz,opt.ngf*8,4,1,0,bias=False))\n",
    "        self.model.add_module('bnorm1',nn.BatchNorm2d(opt.ngf*8))\n",
    "        self.model.add_module('relu1',nn.ReLU(True))\n",
    "        self.model.add_module('deconv2',nn.ConvTranspose2d(opt.ngf*8,opt.ngf*4,4,2,1,bias=False))\n",
    "        self.model.add_module('bnorm2',nn.BatchNorm2d(opt.ngf*4))\n",
    "        self.model.add_module('relu2',nn.ReLU(True))\n",
    "        self.model.add_module('deconv3',nn.ConvTranspose2d(opt.ngf*4,opt.ngf*2,4,2,1,bias=False))\n",
    "        self.model.add_module('bnorm3',nn.BatchNorm2d(opt.ngf*2))\n",
    "        self.model.add_module('relu3',nn.ReLU(True))\n",
    "        self.model.add_module('deconv4',nn.ConvTranspose2d(opt.ngf*2,opt.ngf,4,2,1,bias=False))\n",
    "        self.model.add_module('bnorm4',nn.BatchNorm2d(opt.ngf))\n",
    "        self.model.add_module('relu4',nn.ReLU(True))\n",
    "        self.model.add_module('deconv5',nn.ConvTranspose2d(opt.ngf,opt.nc,4,2,1,bias=False))\n",
    "        self.model.add_module('tanh',nn.Tanh())\n",
    "    def forward(self,input):\n",
    "        gpuids=None\n",
    "        if self.ngpu:\n",
     "            gpuids=range(self.ngpu)\n",
    "        return nn.parallel.data_parallel(self.model,input, device_ids=gpuids)\n",
    "\n",
    "def weight_init(m):\n",
    "    # 参数初始化。 可以改成xavier初始化方法\n",
    "    class_name=m.__class__.__name__\n",
     "    if class_name.find('Conv')!=-1:\n",
     "        m.weight.data.normal_(0,0.02)\n",
     "    if class_name.find('Norm')!=-1:\n",
     "        m.weight.data.normal_(1.0,0.02)\n",
    "    \n",
    "class ModelD(nn.Module):\n",
    "    def __init__(self,ngpu):\n",
    "        super(ModelD,self).__init__()\n",
    "        self.ngpu=ngpu\n",
    "        self.model=nn.Sequential()\n",
    "        self.model.add_module('conv1',nn.Conv2d(opt.nc,opt.ndf,4,2,1,bias=False))\n",
    "        self.model.add_module('relu1',nn.LeakyReLU(0.2,inplace=True))\n",
    "        \n",
    "        self.model.add_module('conv2',nn.Conv2d(opt.ndf,opt.ndf*2,4,2,1,bias=False))\n",
    "        self.model.add_module('bnorm2',nn.BatchNorm2d(opt.ndf*2))\n",
    "        self.model.add_module('relu2',nn.LeakyReLU(0.2,inplace=True))\n",
    "        \n",
    "        self.model.add_module('conv3',nn.Conv2d(opt.ndf*2,opt.ndf*4,4,2,1,bias=False))\n",
    "        self.model.add_module('bnorm3',nn.BatchNorm2d(opt.ndf*4))\n",
    "        self.model.add_module('relu3',nn.LeakyReLU(0.2,inplace=True))\n",
    "        \n",
    "        self.model.add_module('conv4',nn.Conv2d(opt.ndf*4,opt.ndf*8,4,2,1,bias=False))\n",
    "        self.model.add_module('bnorm4',nn.BatchNorm2d(opt.ndf*8))\n",
    "        self.model.add_module('relu4',nn.LeakyReLU(0.2,inplace=True))\n",
    "        \n",
    "        self.model.add_module('conv5',nn.Conv2d(opt.ndf*8,1,4,1,0,bias=False))\n",
    "        # modify: remove sigmoid\n",
    "        #self.model.add_module('sigmoid',nn.Sigmoid())\n",
    "    def forward(self,input):\n",
    "        gpuids=None\n",
    "        if self.ngpu:\n",
     "            gpuids=range(self.ngpu)\n",
    "        return nn.parallel.data_parallel(self.model,input, device_ids=gpuids).view(-1,1).mean(0).view(1)#\n",
    "         ## no loss but score\n",
    "        \n",
    "netg=ModelG(opt.gpuids)\n",
    "netg.apply(weight_init)\n",
    "netd=ModelD(opt.gpuids)\n",
    "netd.apply(weight_init)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 定义优化器\n",
    "optimizerD=optim.RMSprop(netd.parameters(),lr=opt.lr ) #modify : 不要采用基于动量的优化方法 如Adam\n",
    "optimizerG=optim.RMSprop(netg.parameters(),lr=opt.lr )  #  \n",
    "\n",
    "# 定义 D网和G网的输入\n",
    "input=Variable(t.FloatTensor(opt.batch_size,opt.nc,opt.image_size,opt.image_size2))\n",
    "label=Variable(t.FloatTensor(opt.batch_size))\n",
    "noise=Variable(t.FloatTensor(opt.batch_size,opt.nz,1,1))\n",
    "fixed_noise=Variable(t.FloatTensor(opt.batch_size,opt.nz,1,1).normal_(0,1))\n",
    "real_label=1\n",
    "fake_label=0\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "#criterion=nn.BCELoss() # WGAN 不需要log(交叉熵) \n",
    "one=t.FloatTensor([1])\n",
    "mone=-1*one\n",
    "\n",
    "#开始训练\n",
     "for epoch in range(opt.max_epoch):\n",
    "    for ii, data in enumerate(dataloader,0):\n",
    "        #### 训练D网 ####\n",
    "        netd.zero_grad() #有必要\n",
    "        real,_=data\n",
    "        input.data.resize_(real.size()).copy_(real)\n",
    "        label.data.resize_(input.size()[0]).fill_(real_label)\n",
    "        output=netd(input)\n",
    "        output.backward(one)#######for wgan\n",
    "        D_x=output.data.mean()\n",
    "        \n",
    "        noise.data.resize_(input.size()[0],opt.nz,1,1 ).normal_(0,1)\n",
    "        fake_pic=netg(noise).detach()\n",
    "        output2=netd(fake_pic)\n",
    "        label.data.fill_(fake_label)\n",
    "        output2.backward(mone) #for wgan\n",
    "        D_x2=output2.data.mean()        \n",
    "        optimizerD.step()\n",
    "        for parm in netd.parameters():parm.data.clamp_(-opt.clamp_num,opt.clamp_num) ### 只有判别器需要 截断参数\n",
    "        \n",
    "        #### 训练G网 ########\n",
    "        if t.rand(1)[0]>0.8:\n",
    "            # d网和g网的训练次数不一样, 这里d网和g网的训练比例大概是: 5:1\n",
    "            netg.zero_grad()\n",
    "            label.data.fill_(real_label)\n",
    "            noise.data.normal_(0,1)\n",
    "            fake_pic=netg(noise)\n",
    "            output=netd(fake_pic)\n",
    "            output.backward(one)\n",
    "            optimizerG.step()\n",
    "            #for parm in netg.parameters():parm.data.clamp_(-opt.clamp_num,opt.clamp_num)## 只有判别器需要 生成器不需要\n",
    "            D_G_z2=output.data.mean()\n",
    "\n",
    "        if ii%100==0 and ii>0:\n",
    "            fake_u=netg(fixed_noise)\n",
    "            vutil.save_image(fake_u.data,'wgan/fake%s_%s.png'%(epoch,ii))\n",
    "            vutil.save_image(real,'wgan/real%s_%s.png'%(epoch,ii))              "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 85,
   "metadata": {},
   "outputs": [],
   "source": [
    "t.save(netd.state_dict(),'epoch_netd.pth')\n",
    "t.save(netg.state_dict(),'epoch_netg.pth')"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  },
  "toc": {
   "colors": {
    "hover_highlight": "#DAA520",
    "navigate_num": "#000000",
    "navigate_text": "#333333",
    "running_highlight": "#FF0000",
    "selected_highlight": "#FFD700",
    "sidebar_border": "#EEEEEE",
    "wrapper_background": "#FFFFFF"
   },
   "moveMenuLeft": true,
   "nav_menu": {
    "height": "30px",
    "width": "252px"
   },
   "navigate_menu": true,
   "number_sections": true,
   "sideBar": true,
   "threshold": 4,
   "toc_cell": false,
   "toc_section_display": "block",
   "toc_window_display": false,
   "widenNotebook": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}