深度卷积生成对抗网络DCGAN——生成手写数字图片

举报
一颗小树x 发表于 2021/06/30 18:00:17 2021/06/30
【摘要】 ​前言本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。 本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:# 用于生成 GIF 图片pip install -q imageio​​一、什么是生成对抗网络?生成对抗网络(GAN),包含生成...

前言

本文使用深度卷积生成对抗网络(DCGAN)生成手写数字图片,代码使用Keras API与tf.GradientTape 编写的,其中tf.GradientTrape是训练模型时用到的。

 本文用到imageio 库来生成gif图片,如果没有安装的,需要安装下:

# 用于生成 GIF 图片
pip install -q imageio

一、什么是生成对抗网络?

生成对抗网络(GAN),包含生成器和判别器,两个模型通过对抗过程同时训练。

生成器,可以理解为“艺术家、创造者”,它学习创造看起来真实的图像。

判别器,可以理解为“艺术评论家、审核者”,它学习区分真假图像。

训练过程中,生成器在生成逼真图像方便逐渐变强,而判别器在辨别这些图像的能力上逐渐变强。

当判别器不能再区分真实图片和伪造图片时,训练过程达到平衡。


本文,在MNIST数据集上演示了该过程。随着训练的进行,生成器所生成的一系列图片,越来越像真实的手写数字。


二、加载数据集

使用MNIST数据,来训练生成器和判别器。生成器将生成类似于MNIST数据集的手写数字。

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内

BUFFER_SIZE = 60000
BATCH_SIZE = 256

# 批量化和打乱数据
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)


三、创建模型

主要创建两个模型,一个是生成器,另一个是判别器


3.1 生成器

生成器使用 tf.keras.layers.Conv2DTranspose 层,来从随机噪声中产生图片。

然后把从随机噪声中产生图片,作为输入数据,输入到Dense层,开始。

后面,经过多次上采样,达到所预期 28x28x1 的图片尺寸。

def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制

    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())

    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)

    return model

用tf.keras.utils.plot_model( ),看一下模型结构

 

用summary(),看一下模型结构和参数

使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。

generator = make_generator_model()

noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)

plt.imshow(generated_image[0, :, :, 0], cmap='gray')

3.1 判别器

判别器是基于 CNN卷积神经网络 的图片分类器。

def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))

    model.add(layers.Flatten())
    model.add(layers.Dense(1))

    return model

用tf.keras.utils.plot_model( ),看一下模型结构

用summary(),看一下模型结构和参数


四、定义损失函数和优化器

由于有两个模型,一个是生成器,另一个是判别器;所以要分别为两个模型定义损失函数和优化器。

首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。

# 该方法返回计算交叉熵损失的辅助函数
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)


4.1 生成器的损失和优化器

1)生成器损失

生成器损失是量化其欺骗判别器的能力;如果生成器表现良好,判别器将会把伪造图片判断为真实图片(或1)。

这里我们将把判别器在生成图片上的判断结果,与一个值全为1的数组进行对比。

def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)

2)生成器优化器

generator_optimizer = tf.keras.optimizers.Adam(1e-4)


4.2 判别器的损失和优化器

1)判别器损失

判别器损失,是量化判断真伪图片的能力。它将判别器对真实图片的预测值,与全值为1的数组进行对比;将判别器对伪造(生成的)图片的预测值,与全值为0的数组进行对比。

def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss

2)判别器优化器

discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)


五、训练模型

5.1 保存检查点

保存检查点,能帮助保存和恢复模型,在长时间训练任务被中断的情况下比较有帮助。

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)


5.2 定义训练过程

EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16


# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

训练过程中,在生成器接收到一个“随机噪声中产生的图片”作为输入开始。

判别器随后被用于区分真实图片(训练集的)和伪造图片(生成器生成的)。

两个模型都计算损失函数,并且分别计算梯度用于更新生成器与判别器。

# 注意 `tf.function` 的使用
# 该注解使函数被“编译”
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])

    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)

      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)

      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)

    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)

    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))

def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()

    for image_batch in dataset:
      train_step(image_batch)

    # 继续进行时为 GIF 生成图像
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)

    # 每 15 个 epoch 保存一次模型
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)

    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))

  # 最后一个 epoch 结束后生成图片
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)

# 生成与保存图片
def generate_and_save_images(model, epoch, test_input):
  # 注意 training` 设定为 False
  # 因此,所有层都在推理模式下运行(batchnorm)。
  predictions = model(test_input, training=False)

  fig = plt.figure(figsize=(4,4))

  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')

  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()


5.3 训练模型

调用上面定义的train()函数,来同时训练生成器和判别器。

注意,训练GAN可能比较难的;生成器和判别器不能互相压制对方,需要两种达到平衡,它们用相似的学习率训练。

%%time
train(train_dataset, EPOCHS)

在刚开始训练时,生成的图片看起来很像随机噪声,随着训练过程的进行,生成的数字越来越真实。训练大约50轮后,生成器生成的图片看起来很像MNIST数字了。

训练了15轮的效果:

训练了30轮的效果:

训练过程:

恢复最新的检查点

checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))


六、评估模型

这里通过直接查看生成的图片,来看模型的效果。使用训练过程中生成的图片,通过imageio生成动态gif。

# 使用 epoch 数生成单张图片
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))

display_image(EPOCHS)
anim_file = 'dcgan.gif'

with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  last = -1
  for i,filename in enumerate(filenames):
    frame = 2*(i**0.5)
    if round(frame) > round(last):
      last = frame
    else:
      continue
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)

import IPython
if IPython.version_info > (6,2,0,''):
  display.Image(filename=anim_file)

完整代码:

import tensorflow as tf
import glob
import imageio
import matplotlib.pyplot as plt
import numpy as np
import os
import PIL
from tensorflow.keras import layers
import time

from IPython import display

(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()
 
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5 # 将图片标准化到 [-1, 1] 区间内
 
BUFFER_SIZE = 60000
BATCH_SIZE = 256
 
# 批量化和打乱数据
train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)

# 创建模型--生成器
def make_generator_model():
    model = tf.keras.Sequential()
    model.add(layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
 
    model.add(layers.Reshape((7, 7, 256)))
    assert model.output_shape == (None, 7, 7, 256) # 注意:batch size 没有限制
 
    model.add(layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))
    assert model.output_shape == (None, 7, 7, 128)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
 
    model.add(layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))
    assert model.output_shape == (None, 14, 14, 64)
    model.add(layers.BatchNormalization())
    model.add(layers.LeakyReLU())
 
    model.add(layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))
    assert model.output_shape == (None, 28, 28, 1)
 
    return model

# 使用尚未训练的生成器,创建一张图片,这时的图片是随机噪声中产生。
generator = make_generator_model()
 
noise = tf.random.normal([1, 100])
generated_image = generator(noise, training=False)
 
plt.imshow(generated_image[0, :, :, 0], cmap='gray')
tf.keras.utils.plot_model(generator)

# 判别器
def make_discriminator_model():
    model = tf.keras.Sequential()
    model.add(layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same',
                                     input_shape=[28, 28, 1]))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
 
    model.add(layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))
    model.add(layers.LeakyReLU())
    model.add(layers.Dropout(0.3))
 
    model.add(layers.Flatten())
    model.add(layers.Dense(1))
 
    return model

# 使用(尚未训练的)判别器来对图片的真伪进行判断。模型将被训练为为真实图片输出正值,为伪造图片输出负值。
discriminator = make_discriminator_model()
decision = discriminator(generated_image)
print (decision)

# 首先定义一个辅助函数,用于计算交叉熵损失的,这个两个模型通用。
cross_entropy = tf.keras.losses.BinaryCrossentropy(from_logits=True)

# 生成器的损失和优化器
def generator_loss(fake_output):
    return cross_entropy(tf.ones_like(fake_output), fake_output)
generator_optimizer = tf.keras.optimizers.Adam(1e-4)

# 判别器的损失和优化器
def discriminator_loss(real_output, fake_output):
    real_loss = cross_entropy(tf.ones_like(real_output), real_output)
    fake_loss = cross_entropy(tf.zeros_like(fake_output), fake_output)
    total_loss = real_loss + fake_loss
    return total_loss
discriminator_optimizer = tf.keras.optimizers.Adam(1e-4)

# 保存检查点
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(generator_optimizer=generator_optimizer,
                                 discriminator_optimizer=discriminator_optimizer,
                                 generator=generator,
                                 discriminator=discriminator)

# 定义训练过程
EPOCHS = 50
noise_dim = 100
num_examples_to_generate = 16
 
# 我们将重复使用该种子(因此在动画 GIF 中更容易可视化进度)
seed = tf.random.normal([num_examples_to_generate, noise_dim])

# 注意 `tf.function` 的使用
# 该注解使函数被“编译”
@tf.function
def train_step(images):
    noise = tf.random.normal([BATCH_SIZE, noise_dim])
 
    with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:
      generated_images = generator(noise, training=True)
 
      real_output = discriminator(images, training=True)
      fake_output = discriminator(generated_images, training=True)
 
      gen_loss = generator_loss(fake_output)
      disc_loss = discriminator_loss(real_output, fake_output)
 
    gradients_of_generator = gen_tape.gradient(gen_loss, generator.trainable_variables)
    gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.trainable_variables)
 
    generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.trainable_variables))
    discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.trainable_variables))
 
def train(dataset, epochs):
  for epoch in range(epochs):
    start = time.time()
 
    for image_batch in dataset:
      train_step(image_batch)
 
    # 继续进行时为 GIF 生成图像
    display.clear_output(wait=True)
    generate_and_save_images(generator,
                             epoch + 1,
                             seed)
 
    # 每 15 个 epoch 保存一次模型
    if (epoch + 1) % 15 == 0:
      checkpoint.save(file_prefix = checkpoint_prefix)
 
    print ('Time for epoch {} is {} sec'.format(epoch + 1, time.time()-start))
 
  # 最后一个 epoch 结束后生成图片
  display.clear_output(wait=True)
  generate_and_save_images(generator,
                           epochs,
                           seed)
 
# 生成与保存图片
def generate_and_save_images(model, epoch, test_input):
  # 注意 training` 设定为 False
  # 因此,所有层都在推理模式下运行(batchnorm)。
  predictions = model(test_input, training=False)
 
  fig = plt.figure(figsize=(4,4))
 
  for i in range(predictions.shape[0]):
      plt.subplot(4, 4, i+1)
      plt.imshow(predictions[i, :, :, 0] * 127.5 + 127.5, cmap='gray')
      plt.axis('off')
 
  plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))
  plt.show()

# 训练模型
train(train_dataset, EPOCHS)

# 恢复最新的检查点
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

# 评估模型
# 使用 epoch 数生成单张图片
def display_image(epoch_no):
  return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))
 
display_image(EPOCHS)

anim_file = 'dcgan.gif'
 
with imageio.get_writer(anim_file, mode='I') as writer:
  filenames = glob.glob('image*.png')
  filenames = sorted(filenames)
  last = -1
  for i,filename in enumerate(filenames):
    frame = 2*(i**0.5)
    if round(frame) > round(last):
      last = frame
    else:
      continue
    image = imageio.imread(filename)
    writer.append_data(image)
  image = imageio.imread(filename)
  writer.append_data(image)
 
import IPython
if IPython.version_info > (6,2,0,''):
  display.Image(filename=anim_file)

参考:https://www.tensorflow.org/tutorials/generative/dcgan

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。