【万物皆可 GAN】生成对抗网络生成手写数字 Part 1
【摘要】
【万物皆可 GAN】生成对抗网络生成手写数字 Part 1
概述GAN 网络结构GAN 训练流程模型详解生成器判别器
概述
GAN (Generative Adversarial Ne...
概述
GAN (Generative Adversarial Network) 即生成对抗网络. GAN 网络包括一个生成器 (Generator) 和一个判别器 (Discriminator). GAN 可以自动提取特征, 并判断和优化.
GAN 网络结构
生成器 (Generator) 输入一个向量, 输出手写数字大小的像素图像.
判别器 (Discriminator) 输入图片, 判断图片是来自数据集还是来自生成器的, 输出标签 (Real / Fake)
GAN 训练流程
第一阶段:
- 固定判别器, 训练生成器: 使得生成器的技能不断提升, 骗过判别器
第二阶段:
- 固定生成器, 训练判别器: 使得判别器的技能不断提升, 生成器无法骗过判别器
然后:
- 循环第一阶段和第二阶段, 使得生成器和判别器都越来越强
模型详解
生成器
class Generator(nn.Module):
"""生成器"""
def __init__(self, latent_dim, img_shape):
super(Generator, self).__init__()
def block(in_feat, out_feat, normalize=True):
"""
block
:param in_feat: 输入的特征维度
:param out_feat: 输出的特征维度
:param normalize: 归一化
:return: block
"""
layers = [nn.Linear(in_feat, out_feat)]
# 归一化
if normalize:
layers.append(nn.BatchNorm1d(out_feat, 0.8))
# 激活
layers.append(nn.LeakyReLU(0.2, inplace=True))
return layers
self.model = nn.Sequential(
# [b, 100] => [b, 128]
*block(latent_dim, 128, normalize=False),
# [b, 128] => [b, 256]
*block(128, 256),
# [b, 256] => [b, 512]
*block(256, 512),
# [b, 512] => [b, 1024]
*block(512, 1024),
# [b, 1024] => [b, 28 * 28 * 1] => [b, 784]
nn.Linear(1024, int(np.prod(img_shape))),
# 激活
nn.Tanh()
)
def forward(self, z, img_shape):
# [b, 100] => [b, 784]
img = self.model(z)
# [b, 784] => [b, 1, 28, 28]
img = img.view(img.size(0), *img_shape)
# 返回生成的图片
return img
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
- 27
- 28
- 29
- 30
- 31
- 32
- 33
- 34
- 35
- 36
- 37
- 38
- 39
- 40
- 41
- 42
- 43
- 44
- 45
- 46
- 47
网络结构:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Linear-1 [-1, 128] 12,928
LeakyReLU-2 [-1, 128] 0
Linear-3 [-1, 256] 33,024
BatchNorm1d-4 [-1, 256] 512
LeakyReLU-5 [-1, 256] 0
Linear-6 [-1, 512] 131,584
BatchNorm1d-7 [-1, 512] 1,024
LeakyReLU-8 [-1, 512] 0
Linear-9 [-1, 1024] 525,312
BatchNorm1d-10 [-1, 1024] 2,048
LeakyReLU-11 [-1, 1024] 0
Linear-12 [-1, 784] 803,600
Tanh-13 [-1, 784] 0
================================================================
Total params: 1,510,032
Trainable params: 1,510,032
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.05
Params size (MB): 5.76
Estimated Total Size (MB): 5.82
----------------------------------------------------------------
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
- 24
- 25
- 26
判别器
class Discriminator(nn.Module):
"""判断器"""
def __init__(self, img_shape):
super(Discriminator, self).__init__()
self.model = nn.Sequential(
# 就是个线性回归
nn.Linear(int(np.prod(img_shape)), 512),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(512, 256),
nn.LeakyReLU(0.2, inplace=True),
nn.Linear(256, 1),
nn.Sigmoid(),
)
def forward(self, img):
# 压平
img_flat = img.view(img.size(0), -1)
validity = self.model(img_flat)
return validity
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
- 19
- 20
- 21
- 22
- 23
网络结构:
----------------------------------------------------------------
Layer (type) Output Shape Param #
================================================================
Linear-1 [-1, 512] 401,920
LeakyReLU-2 [-1, 512] 0
Linear-3 [-1, 256] 131,328
LeakyReLU-4 [-1, 256] 0
Linear-5 [-1, 1] 257
Sigmoid-6 [-1, 1] 0
================================================================
Total params: 533,505
Trainable params: 533,505
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.00
Forward/backward pass size (MB): 0.01
Params size (MB): 2.04
Estimated Total Size (MB): 2.05
- 1
- 2
- 3
- 4
- 5
- 6
- 7
- 8
- 9
- 10
- 11
- 12
- 13
- 14
- 15
- 16
- 17
- 18
文章来源: iamarookie.blog.csdn.net,作者:我是小白呀,版权归原作者所有,如需转载,请联系作者。
原文链接:iamarookie.blog.csdn.net/article/details/118667375
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)