ACGAN-动漫头像自动生成
ACGAN
论文:Conditional Image Synthesis with Auxiliary Classifier GANs
使用标签的数据集应用于生成对抗网络可以增强现有的生成模型,并形成两种优化思路。
- cGAN使用了辅助的标签信息来增强原始GAN,对生成器和判别器都使用标签数据进行训练,从而实现模型具备产生特定条件数据的能力。
- SGAN的结构来利用辅助标签信息(少量标签),利用判别器或者分类器的末端重建标签信息。 ACGAN则是结合以上两种思路对GAN进行优化。
ACGAN目标函数:
对于生成器来说有两个输入,一个是标签的分类数据c,另一个是随机数据z,得到生成数据为 ; 对于判别器分别要判断数据源是否为真实数据的概率分布 ,以及数据源对于分类标签的概率分布
ACGAN的目标函数包含两部分: 第一部分 是面向数据真实与否的代价函数 第二部分 则是数据分类准确性的代价函数。
在优化过程中希望判别器D能否使得 + 尽可能最大,而生成器G使得 - 尽可能最大; 简而言之是希望判别器能够尽可能区分真实数据和生成数据并且能有效对数据进行分类,对生成器来说希望生成数据被尽可能认为是真实数据且数据都能够被有效分类。
1.本案例使用框架:TensorFlow 1.13.1
2.本案例使用硬件:GPU: 1*NVIDIA-V100NV32(32GB) | CPU: 8 核 64GB
3.运行代码方法: 点击本页面顶部菜单栏的三角形运行按钮或按Ctrl+Enter键 运行每个方块中的代码
4.JupyterLab的详细用法: 请参考《ModelAtrs JupyterLab使用指导》
5.碰到问题的解决办法: 请参考《ModelAtrs JupyterLab常见问题解决办法》
1.下载模型和代码
import os
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/ACGAN.zip
# 解压缩
os.system('unzip ACGAN.zip -d ./')
2.模型训练
2.1加载依赖库
root_path = './ACGAN/'
os.chdir(root_path)
import os
from main import main
from ACGAN import ACGAN
from tools import checkFolder
import tensorflow as tf
import argparse
import numpy as np
2.2设置参数
def parse_args():
note = "ACGAN Frame Constructed With Tensorflow"
parser = argparse.ArgumentParser(description=note)
parser.add_argument("--epoch",type=int,default=251,help="训练轮数")
parser.add_argument("--batchSize",type=int,default=64,help="batch的大小")
parser.add_argument("--codeSize",type=int,default=62,help="输入编码向量的维度")
parser.add_argument("--checkpointDir",type=str,default="./checkpoint",help="检查点保存目录")
parser.add_argument("--resultDir",type=str,default="./result",help="训练过程中,中间生成结果的目录")
parser.add_argument("--logDir",type=str,default="./log",help="训练日志目录")
parser.add_argument("--mode",type=str,default="train",help="模式: train / infer")
parser.add_argument("--hairStyle",type=str,default="orange hair",help="你想要生成的动漫头像的头发颜色")
parser.add_argument("--eyeStyle",type=str,default="gray eyes",help="你想要生成的动漫头像的眼睛颜色")
parser.add_argument("--dataSource",type=str,default='./extra_data/images/',help="训练集路径")
args, unknown= parser.parse_known_args()
checkFolder(args.checkpointDir)
checkFolder(args.resultDir)
checkFolder(args.logDir)
assert args.epoch>=1
assert args.batchSize>=1
assert args.codeSize>=1
return args
args =parse_args()
2.3开始训练
with tf.Session() as sess :
myGAN = ACGAN(sess,args.epoch,args.batchSize,args.codeSize,\
args.dataSource,args.checkpointDir,args.resultDir,args.logDir,args.mode,\
64,64,3)
if myGAN is None:
print("创建GAN网络失败")
exit(0)
if args.mode=='train' :
myGAN.buildNet()
print("进入训练模式")
myGAN.train()
print("Done")
开始加载数据集!
images.shape: (3000, 64, 64, 3)
labels.shape: (3000, 23)
Loading images to numpy array...
Random shuffling images and labels...
[Tip 1] Normalize the images between -1 and 1.
数据集加载成功!
numOfBatches : 46
网络实例化:
WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/ops/tensor_array_ops.py:162: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/contrib/layers/python/layers/layers.py:1624: flatten (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.flatten instead.
WARNING:tensorflow:From /home/ma-user/work/ACGAN/ACGAN.py:167: batch_normalization (from tensorflow.python.layers.normalization) is deprecated and will be removed in a future version.
Instructions for updating:
Use keras.layers.batch_normalization instead.
已构建 Loss for Discriminator
已构建 Loss for Generator
# size of dVars : 55
# size of gVars : 148
WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
已构建优化器
已构造预测器
网络实例化成功!
进入训练模式
开始配置训练环境!
模型将会被加载 : ./checkpoint/ACGAN
WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.
Instructions for updating:
Use standard file APIs to check for files with this prefix.
INFO:tensorflow:Restoring parameters from ./checkpoint/ACGAN/ACGAN.model-251
MODEL NAME : ACGAN.model-251
模型加载成功 : ACGAN.model-251
加载成功
生成模型结果预览
训练开始!~~~~~~~~~~~~~~~~~~~~~~~~
251 251
Done
修改参数从训练模式为推理模式
args.mode ='infer'
从标签里选择你想要生成的头像的头发和眼睛,只能从这两个列表里选择
hair_dict = ['orange hair', 'white hair', 'aqua hair', 'gray hair', 'green hair', 'red hair', 'purple hair',
'pink hair', 'blue hair', 'black hair', 'brown hair', 'blonde hair']
eye_dict = [ 'gray eyes', 'black eyes', 'orange eyes', 'pink eyes', 'yellow eyes',
'aqua eyes', 'purple eyes', 'green eyes', 'brown eyes', 'red eyes', 'blue eyes']
# 选择了黄头发和灰眼睛
args.hairStyle = 'orange hair'
args.eyeStyle = 'gray eyes'
构造预测器
tf.reset_default_graph()
with tf.Session() as sess :
myGAN1 = ACGAN(sess,args.epoch,args.batchSize,args.codeSize,\
args.dataSource,args.checkpointDir,args.resultDir,args.logDir,args.mode,\
64,64,3)
if myGAN1 is None:
print("创建GAN网络失败")
exit(0)
if args.mode=='infer' :
myGAN1.buildForInfer()
tag_dict = ['orange hair', 'white hair', 'aqua hair', 'gray hair', 'green hair', 'red hair', 'purple hair', 'pink hair', 'blue hair', 'black hair',
'brown hair', 'blonde hair','gray eyes', 'black eyes', 'orange eyes', 'pink eyes', 'yellow eyes','aqua eyes', 'purple eyes', 'green eyes',
'brown eyes', 'red eyes','blue eyes']
tag = np.zeros((64,23))
feature = args.hairStyle+" AND "+ args.eyeStyle
for j in range(25):
for i in range(len(tag_dict)):
if tag_dict[i] in feature:
tag[j][i] = 1
myGAN1.infer(tag,feature)
print("Generate : "+feature)
模型将会被加载 : ./checkpoint/ACGAN
INFO:tensorflow:Restoring parameters from ./checkpoint/ACGAN/ACGAN.model-251
MODEL NAME : ACGAN.model-251
模型加载成功 : ACGAN.model-251
已构造预测器
Generate : orange hair AND gray eyes
开始生成黄色头发,灰色眼睛的动漫头像
存在生成不了正确头像的情况
import matplotlib.pyplot as plt
from PIL import Image
feature = args.hairStyle+" AND "+ args.eyeStyle
resultPath = './samples/' + feature + '.png' #确定保存路径
img = Image.open(resultPath).convert('RGB')
plt.figure(1)
plt.imshow(img)
plt.show()
- 点赞
- 收藏
- 关注作者
评论(0)