CycleGAN Image Style Transfer
[Abstract] CycleGAN is a landmark model for image-to-image translation that requires no paired training samples, e.g. converting zebras into horses, or photos of fashion models into cartoon characters. Its defining idea is a cycle: an image is first translated from one domain to the other, and then translated back; if both translations are accurate, the reconstructed image should closely match the original input. Through this cycle, CycleGAN effectively pairs images before and after translation, much like supervised learning, which improves translation quality.
CycleGAN
Paper: https://arxiv.org/pdf/1703.10593.pdf
CycleGAN is essentially two mirror-symmetric GANs that together form a ring network.
The two GANs share the two generators, and each has its own discriminator, so the full model contains two generators and two discriminators. Each one-directional GAN carries two losses, so the two together carry four losses in total.
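To make the cycle idea concrete, here is a minimal, purely illustrative sketch of one cycle-consistency term. The toy stand-in "generators" below are hypothetical (the real generators are ResNet-based networks); the weight lambda_A = 10.0 matches the value printed in the training options later in this notebook.

```python
# Toy stand-ins for the two generators (hypothetical; the real ones are
# ResNet networks trained adversarially).
def G_A(x):          # maps domain A -> B
    return [v + 0.1 for v in x]

def G_B(y):          # maps domain B -> A
    return [v - 0.1 for v in y]

def l1(a, b):        # mean absolute error, the metric used for the cycle loss
    return sum(abs(u - v) for u, v in zip(a, b)) / len(a)

real_A = [0.2, 0.5, 0.8]
lambda_A = 10.0      # cycle weight for the A -> B -> A direction

rec_A = G_B(G_A(real_A))                     # round trip: A -> B -> A
cycle_loss_A = lambda_A * l1(rec_A, real_A)  # cycle-consistency term
print(cycle_loss_A)  # near zero: these toy generators invert each other
```

The symmetric term for the B -> A -> B direction (weighted by lambda_B) plus the two adversarial losses give the four losses described above.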
Notes:
Framework: PyTorch 1.4.0
Hardware: 8 vCPUs + 64 GiB RAM + 1 x Tesla V100-PCIE-32GB
How to run the code: click the triangular Run button in the menu bar at the top of this page, or press Ctrl+Enter, to execute each code cell
Detailed JupyterLab usage: see the "ModelArts JupyterLab User Guide"
Troubleshooting: see the "ModelArts JupyterLab FAQ"
1. Download the code and dataset
Run the code below to download and extract the code and data.
We use the horse2zebra dataset, located in CycleGAN/datasets/horse2zebra/.
import os
# Download the code and data
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/CycleGAN.zip
# Extract the archive
os.system('unzip CycleGAN.zip -d ./')
os.chdir('./CycleGAN')
--2021-07-14 16:13:08-- https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/CycleGAN.zip
Resolving proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)... 192.168.0.172
Connecting to proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)|192.168.0.172|:8083... connected.
Proxy request sent, awaiting response... 200 OK
Length: 486452880 (464M) [application/zip]
Saving to: ‘CycleGAN.zip’
CycleGAN.zip 100%[===================>] 463.92M 248MB/s in 1.9s
2021-07-14 16:13:10 (248 MB/s) - ‘CycleGAN.zip’ saved [486452880/486452880]
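As a quick sanity check after extraction, the sketch below verifies the dataset layout. The trainA/trainB/testA/testB split is the standard CycleGAN dataset structure; that this zip follows it is an assumption.

```python
import os

def missing_splits(data_dir, expected=('trainA', 'trainB', 'testA', 'testB')):
    """Return the expected dataset sub-directories absent from data_dir."""
    if not os.path.isdir(data_dir):
        return list(expected)
    present = set(os.listdir(data_dir))
    return [d for d in expected if d not in present]

# Prints [] when the download and unzip completed successfully
print(missing_splits('./datasets/horse2zebra'))
```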
2. Training
2.1 Install dependencies
!pip install -r requirements.txt
Looking in indexes: http://repo.myhuaweicloud.com/repository/pypi/simple
Requirement already satisfied: torch>=1.4.0 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from -r requirements.txt (line 1)) (1.4.0)
Requirement already satisfied: torchvision>=0.5.0 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from -r requirements.txt (line 2)) (0.5.0)
Requirement already satisfied: dominate>=2.4.0 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from -r requirements.txt (line 3)) (2.6.0)
Requirement already satisfied: visdom>=0.1.8.8 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from -r requirements.txt (line 4)) (0.1.8.9)
Requirement already satisfied: pillow>=4.1.1 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (6.2.0)
Requirement already satisfied: numpy in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (1.20.2)
Requirement already satisfied: six in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (1.15.0)
Requirement already satisfied: jsonpatch in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (1.32)
Requirement already satisfied: torchfile in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (0.1.0)
Requirement already satisfied: tornado in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (6.1)
Requirement already satisfied: requests in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (2.25.1)
Requirement already satisfied: pyzmq in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (22.0.3)
Requirement already satisfied: websocket-client in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (1.1.0)
Requirement already satisfied: scipy in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (1.3.2)
Requirement already satisfied: jsonpointer>=1.9 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from jsonpatch->visdom>=0.1.8.8->-r requirements.txt (line 4)) (2.1)
Requirement already satisfied: chardet<5,>=3.0.2 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (4.0.0)
Requirement already satisfied: certifi>=2017.4.17 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (2020.12.5)
Requirement already satisfied: idna<3,>=2.5 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (2.10)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (1.26.4)
WARNING: You are using pip version 20.3.3; however, version 21.1.3 is available.
You should consider upgrading via the '/home/ma-user/anaconda3/envs/PyTorch-1.4/bin/python -m pip install --upgrade pip' command.
2.2 Start training
Training parameters can be viewed and modified in CycleGAN/options/train_options.py.
If you use a different dataset, change the data path accordingly.
The model is named horse2zebra.
!python train.py --dataroot ./datasets/horse2zebra --name horse2zebra --model cycle_gan
----------------- Options ---------------
batch_size: 1
beta1: 0.5
checkpoints_dir: ./checkpoints
continue_train: False
crop_size: 256
dataroot: ./datasets/horse2zebra [default: None]
dataset_mode: unaligned
direction: AtoB
display_env: main
display_freq: 400
display_id: 1
display_ncols: 4
display_port: 8097
display_server: http://localhost
display_winsize: 256
epoch: latest
epoch_count: 1
gan_mode: lsgan
gpu_ids: 0
init_gain: 0.02
init_type: normal
input_nc: 3
isTrain: True [default: None]
lambda_A: 10.0
lambda_B: 10.0
lambda_identity: 0.5
load_iter: 0 [default: 0]
load_size: 286
lr: 0.0002
lr_decay_iters: 50
lr_policy: linear
max_dataset_size: inf
model: cycle_gan
n_epochs: 1
n_epochs_decay: 1
n_layers_D: 3
name: horse2zebra [default: experiment_name]
ndf: 64
netD: basic
netG: resnet_9blocks
ngf: 64
no_dropout: True
no_flip: False
no_html: False
norm: instance
num_threads: 4
output_nc: 3
phase: train
pool_size: 50
preprocess: resize_and_crop
print_freq: 100
save_by_iter: False
save_epoch_freq: 1
save_latest_freq: 5000
serial_batches: False
suffix:
update_html_freq: 1000
verbose: False
----------------- End -------------------
dataset [UnalignedDataset] was created
The number of training images = 1334
initialize network with normal
initialize network with normal
initialize network with normal
initialize network with normal
model [CycleGANModel] was created
---------- Networks initialized -------------
[Network G_A] Total number of parameters : 11.378 M
[Network G_B] Total number of parameters : 11.378 M
[Network D_A] Total number of parameters : 2.765 M
[Network D_B] Total number of parameters : 2.765 M
-----------------------------------------------
/home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:122: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
learning rate 0.0002000 -> 0.0001000
(epoch: 1, iters: 100, time: 0.205, data: 0.142)
(epoch: 1, iters: 200, time: 0.205, data: 0.001)
(epoch: 1, iters: 300, time: 0.204, data: 0.002)
(epoch: 1, iters: 400, time: 0.204, data: 0.001)
(epoch: 1, iters: 500, time: 0.204, data: 0.002)
(epoch: 1, iters: 600, time: 0.204, data: 0.001)
(epoch: 1, iters: 700, time: 0.204, data: 0.002)
(epoch: 1, iters: 800, time: 0.204, data: 0.002)
(epoch: 1, iters: 900, time: 0.204, data: 0.001)
(epoch: 1, iters: 1000, time: 0.204, data: 0.001)
(epoch: 1, iters: 1100, time: 0.204, data: 0.001)
(epoch: 1, iters: 1200, time: 0.204, data: 0.002)
(epoch: 1, iters: 1300, time: 0.202, data: 0.001)
saving the model at the end of epoch 1, iters 1334
End of epoch 1 / 2 Time Taken: 255 sec
learning rate 0.0001000 -> 0.0000000
(epoch: 2, iters: 66, time: 0.204, data: 0.001)
(epoch: 2, iters: 166, time: 0.204, data: 0.002)
(epoch: 2, iters: 266, time: 0.203, data: 0.002)
(epoch: 2, iters: 366, time: 0.204, data: 0.002)
(epoch: 2, iters: 466, time: 0.204, data: 0.001)
(epoch: 2, iters: 566, time: 0.204, data: 0.002)
(epoch: 2, iters: 666, time: 0.204, data: 0.002)
(epoch: 2, iters: 766, time: 0.204, data: 0.002)
(epoch: 2, iters: 866, time: 0.204, data: 0.002)
(epoch: 2, iters: 966, time: 0.204, data: 0.002)
(epoch: 2, iters: 1066, time: 0.202, data: 0.001)
(epoch: 2, iters: 1166, time: 0.204, data: 0.002)
(epoch: 2, iters: 1266, time: 0.204, data: 0.002)
saving the model at the end of epoch 2, iters 2668
End of epoch 2 / 2 Time Taken: 254 sec
3. Testing
Check that the horse2zebra model trained above was saved; if so, it will appear under the checkpoints directory.
!ls checkpoints/
horse2zebra horse2zebra_pretrained
Run the test:
!python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
----------------- Options ---------------
aspect_ratio: 1.0
batch_size: 1
checkpoints_dir: ./checkpoints
crop_size: 256
dataroot: datasets/horse2zebra/testA [default: None]
dataset_mode: single
direction: AtoB
display_winsize: 256
epoch: latest
eval: False
gpu_ids: 0
init_gain: 0.02
init_type: normal
input_nc: 3
isTrain: False [default: None]
load_iter: 0 [default: 0]
load_size: 256
max_dataset_size: inf
model: test
model_suffix:
n_layers_D: 3
name: horse2zebra_pretrained [default: experiment_name]
ndf: 64
netD: basic
netG: resnet_9blocks
ngf: 64
no_dropout: True [default: False]
no_flip: False
norm: instance
num_test: 50
num_threads: 4
output_nc: 3
phase: test
preprocess: resize_and_crop
results_dir: ./results/
serial_batches: False
suffix:
verbose: False
----------------- End -------------------
dataset [SingleDataset] was created
initialize network with normal
model [TestModel] was created
loading the model from ./checkpoints/horse2zebra_pretrained/latest_net_G.pth
---------- Networks initialized -------------
[Network G] Total number of parameters : 11.378 M
-----------------------------------------------
creating web directory ./results/horse2zebra_pretrained/test_latest
processing (0000)-th image... ['datasets/horse2zebra/testA/n02381460_1000.jpg']
processing (0005)-th image... ['datasets/horse2zebra/testA/n02381460_1110.jpg']
processing (0010)-th image... ['datasets/horse2zebra/testA/n02381460_1260.jpg']
processing (0015)-th image... ['datasets/horse2zebra/testA/n02381460_1420.jpg']
processing (0020)-th image... ['datasets/horse2zebra/testA/n02381460_1690.jpg']
processing (0025)-th image... ['datasets/horse2zebra/testA/n02381460_1830.jpg']
processing (0030)-th image... ['datasets/horse2zebra/testA/n02381460_2050.jpg']
processing (0035)-th image... ['datasets/horse2zebra/testA/n02381460_2460.jpg']
processing (0040)-th image... ['datasets/horse2zebra/testA/n02381460_2870.jpg']
processing (0045)-th image... ['datasets/horse2zebra/testA/n02381460_3040.jpg']
Display the test results
The generated results can be found under ./results/horse2zebra_pretrained/.
import matplotlib.pyplot as plt
img = plt.imread('./results/horse2zebra_pretrained/test_latest/images/n02381460_1010_fake.png')
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7f89bdf93bd0>
import matplotlib.pyplot as plt
img = plt.imread('./results/horse2zebra_pretrained/test_latest/images/n02381460_1010_real.png')
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7f89826bced0>
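The two display cells above can be combined into a single side-by-side comparison. This sketch assumes the results layout written by test.py in this run, i.e. results/&lt;name&gt;/test_latest/images/&lt;id&gt;_real.png and &lt;id&gt;_fake.png; the plotting part only runs if both files exist.

```python
import os

def result_paths(results_dir, name, image_id):
    """Build the real/fake image paths under results/<name>/test_latest/images/
    (layout assumed from the test output above)."""
    base = os.path.join(results_dir, name, 'test_latest', 'images')
    return (os.path.join(base, image_id + '_real.png'),
            os.path.join(base, image_id + '_fake.png'))

paths = result_paths('./results', 'horse2zebra_pretrained', 'n02381460_1010')
if all(os.path.exists(p) for p in paths):
    import matplotlib.pyplot as plt
    fig, axes = plt.subplots(1, 2, figsize=(8, 4))
    for ax, path, title in zip(axes, paths, ('real', 'fake')):
        ax.imshow(plt.imread(path))   # show input horse and generated zebra
        ax.set_title(title)
        ax.axis('off')
    plt.show()
```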
[Copyright notice] This article is original content by a Huawei Cloud community user. When reprinting, you must credit the source (Huawei Cloud community) and include the article link, author, and other basic information; otherwise the author and the community reserve the right to pursue liability. If you find suspected plagiarism in this community, please report it by email with supporting evidence; once verified, the infringing content will be removed immediately. Report email: cloudbbs@huaweicloud.com