【万物皆可 GAN】生成对抗网络生成手写数字 Part 1

举报
我是小白呀iamarookie 发表于 2021/09/10 00:08:35 2021/09/10
【摘要】 【万物皆可 GAN】生成对抗网络生成手写数字 Part 1 概述GAN 网络结构GAN 训练流程模型详解生成器判别器 概述 GAN (Generative Adversarial Ne...

【万物皆可 GAN】生成对抗网络生成手写数字 Part 1

概述

GAN (Generative Adversarial Network) 即生成对抗网络. GAN 网络包括一个生成器 (Generator) 和一个判别器 (Discriminator). GAN 可以自动提取特征, 并判断和优化.
在这里插入图片描述

GAN 网络结构

生成器 (Generator) 输入一个向量, 输出手写数字大小的像素图像.

在这里插入图片描述

判别器 (Discriminator) 输入图片, 判断图片是来自数据集还是来自生成器的, 输出标签 (Real / Fake)

GAN 训练流程

在这里插入图片描述
第一阶段:

  • 固定判别器, 训练生成器: 使得生成器的技能不断提升, 骗过判别器

第二阶段:

  • 固定生成器, 训练判别器: 使得判别器的技能不断提升, 生成器无法骗过判别器

然后:

  • 循环第一阶段和第二阶段, 使得生成器和判别器都越来越强

模型详解

生成器

class Generator(nn.Module):
    """生成器"""

    def __init__(self, latent_dim, img_shape):
        super(Generator, self).__init__()

        def block(in_feat, out_feat, normalize=True):
            """
            block
            :param in_feat: 输入的特征维度
            :param out_feat: 输出的特征维度
            :param normalize: 归一化
            :return: block
            """
            layers = [nn.Linear(in_feat, out_feat)]

            # 归一化
            if normalize:
                layers.append(nn.BatchNorm1d(out_feat, 0.8))

            # 激活
            layers.append(nn.LeakyReLU(0.2, inplace=True))
            return layers

        self.model = nn.Sequential(
            # [b, 100] => [b, 128]
            *block(latent_dim, 128, normalize=False),
            # [b, 128] => [b, 256]
            *block(128, 256),
            # [b, 256] => [b, 512]
            *block(256, 512),
            # [b, 512] => [b, 1024]
            *block(512, 1024),
            # [b, 1024] => [b, 28 * 28 * 1] => [b, 784]
            nn.Linear(1024, int(np.prod(img_shape))),
            # 激活
            nn.Tanh()
        )

    def forward(self, z, img_shape):
        # [b, 100] => [b, 784]
        img = self.model(z)
        # [b, 784] => [b, 1, 28, 28]
        img = img.view(img.size(0), *img_shape)
        
        # 返回生成的图片
        return img

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47

网络结构:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                  [-1, 128]          12,928
         LeakyReLU-2                  [-1, 128]               0
            Linear-3                  [-1, 256]          33,024
       BatchNorm1d-4                  [-1, 256]             512
         LeakyReLU-5                  [-1, 256]               0
            Linear-6                  [-1, 512]         131,584
       BatchNorm1d-7                  [-1, 512]           1,024
         LeakyReLU-8                  [-1, 512]               0
            Linear-9                 [-1, 1024]         525,312
      BatchNorm1d-10                 [-1, 1024]           2,048
        LeakyReLU-11                 [-1, 1024]               0
           Linear-12                  [-1, 784]         803,600
             Tanh-13                  [-1, 784]               0
================================================================
Total params: 1,510,032
Trainable params: 1,510,032
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 5.76
Estimated Total Size (MB): 5.82
----------------------------------------------------------------

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26

判别器

class Discriminator(nn.Module):
    """判断器"""
    
    def __init__(self, img_shape):
        super(Discriminator, self).__init__()

        self.model = nn.Sequential(
            # 就是个线性回归
            nn.Linear(int(np.prod(img_shape)), 512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, img):
        # 压平
        img_flat = img.view(img.size(0), -1)

        validity = self.model(img_flat)

        return validity

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23

网络结构:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Linear-1                  [-1, 512]         401,920
         LeakyReLU-2                  [-1, 512]               0
            Linear-3                  [-1, 256]         131,328
         LeakyReLU-4                  [-1, 256]               0
            Linear-5                    [-1, 1]             257
           Sigmoid-6                    [-1, 1]               0
================================================================
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 2.04
Estimated Total Size (MB): 2.05

  
 
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18

文章来源: iamarookie.blog.csdn.net,作者:我是小白呀,版权归原作者所有,如需转载,请联系作者。

原文链接:iamarookie.blog.csdn.net/article/details/118667375

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。