CycleGAN Image Style Transfer

Posted by HWCloudAI on 2022/12/01 11:01:47
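[Abstract] CycleGAN is a landmark work in image-to-image translation: it learns to translate between domains without paired training samples, for example turning zebras into horses or photos of models into cartoon characters. Its key idea is a cycle. An image is first translated from one domain to the other and then translated back; if both translations are accurate, the reconstructed image should be nearly identical to the input. Through this cycle, CycleGAN pairs each image with its own round-trip reconstruction, which provides a signal similar to supervised learning and improves translation quality.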

CycleGAN

Paper: https://arxiv.org/pdf/1703.10593.pdf

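CycleGAN is essentially two mirror-symmetric GANs that together form a ring-shaped network.
The two GANs share the two generators and each has its own discriminator, so there are two generators and two discriminators in total. Each one-directional GAN contributes two losses, an adversarial loss and a cycle-consistency (reconstruction) loss, giving four losses in total.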
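To make the four losses concrete, here is a minimal NumPy sketch of the generator-side objective for one training step. It assumes the least-squares adversarial loss (matching gan_mode: lsgan in the training options printed later) and cycle weights lambda_A = lambda_B = 10; the names G_A, G_B, D_A, D_B follow the network names in the training log, but here they are stand-in callables rather than the repository's networks, and the optional identity loss (lambda_identity) is omitted.

```python
import numpy as np

def lsgan_loss(pred, target_is_real):
    """Least-squares GAN loss: mean squared error against 1 (real) or 0 (fake)."""
    target = 1.0 if target_is_real else 0.0
    return float(np.mean((pred - target) ** 2))

def l1_loss(a, b):
    """Mean absolute error, used for the cycle-consistency terms."""
    return float(np.mean(np.abs(a - b)))

def generator_losses(real_A, real_B, G_A, G_B, D_A, D_B, lambda_A=10.0, lambda_B=10.0):
    """Return the four CycleGAN generator losses for one batch (sketch).

    G_A maps domain A -> B and G_B maps B -> A.
    D_A scores domain-B images (outputs of G_A); D_B scores domain-A images.
    """
    fake_B = G_A(real_A)   # A -> B
    rec_A = G_B(fake_B)    # A -> B -> A, should reconstruct real_A
    fake_A = G_B(real_B)   # B -> A
    rec_B = G_A(fake_A)    # B -> A -> B, should reconstruct real_B
    return {
        # adversarial terms: each generator wants its discriminator to answer "real"
        "G_A": lsgan_loss(D_A(fake_B), True),
        "G_B": lsgan_loss(D_B(fake_A), True),
        # cycle-consistency terms, weighted by lambda
        "cycle_A": lambda_A * l1_loss(rec_A, real_A),
        "cycle_B": lambda_B * l1_loss(rec_B, real_B),
    }

# Toy usage: identity "generators" reconstruct perfectly, and a discriminator that
# always answers "real" gives zero adversarial loss, so every term is 0.
rng = np.random.default_rng(0)
x = rng.random((1, 3, 8, 8))
y = rng.random((1, 3, 8, 8))
identity = lambda img: img
always_real = lambda img: np.ones((1, 1, 2, 2))
losses = generator_losses(x, y, identity, identity, always_real, always_real)
print(sum(losses.values()))  # 0.0
```

The total generator objective is the sum of the four entries; the discriminators are trained separately with the opposite targets.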

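Notes:

Framework: PyTorch 1.4.0

Hardware: 8 vCPU + 64 GiB + 1 x Tesla V100-PCIE-32GB

Running the code: click the triangular Run button in the menu bar at the top of this page, or press Ctrl+Enter, to run the code in each cell.

Detailed JupyterLab usage: see the "ModelArts JupyterLab User Guide".

Troubleshooting: see "Solutions to Common ModelArts JupyterLab Problems".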

1. Download the code and dataset

Run the code below to download and unzip the code and data.

The horse2zebra dataset is used; the data is located in CycleGAN/datasets/horse2zebra/.

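import os
# Download the code and data package (the leading ! runs a shell command in Jupyter)
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/CycleGAN.zip
# Unzip into the current directory, then switch into the project folder
os.system('unzip CycleGAN.zip -d ./')
os.chdir('./CycleGAN')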
--2021-07-14 16:13:08--  https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/CycleGAN.zip

Resolving proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)... 192.168.0.172

Connecting to proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)|192.168.0.172|:8083... connected.

Proxy request sent, awaiting response... 200 OK

Length: 486452880 (464M) [application/zip]

Saving to: ‘CycleGAN.zip’



CycleGAN.zip        100%[===================>] 463.92M   248MB/s    in 1.9s    



2021-07-14 16:13:10 (248 MB/s) - ‘CycleGAN.zip’ saved [486452880/486452880]
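Because the data is unpaired, the two domains are read from two independent folders. The following pure-Python sketch assumes the unpacked dataset uses trainA (horses) and trainB (zebras) subfolders, as in the upstream repository layout; check with ls before relying on it. It mimics how an "unaligned" dataset (dataset_mode: unaligned in the training log) can pair images: index i takes image i mod |A| from domain A and a random image from domain B, and one epoch covers the larger folder, which would be consistent with the log reporting 1334 training images. UnpairedSampler is an illustrative name, not the repository's UnalignedDataset class.

```python
import random
from pathlib import Path

IMG_EXTENSIONS = {".jpg", ".jpeg", ".png"}

def list_images(folder):
    """Sorted list of image files directly inside `folder`."""
    return sorted(p for p in Path(folder).iterdir() if p.suffix.lower() in IMG_EXTENSIONS)

class UnpairedSampler:
    """Draws (A, B) image pairs from two folders with no correspondence between them."""

    def __init__(self, dir_A, dir_B, seed=None):
        self.A = list_images(dir_A)
        self.B = list_images(dir_B)
        self.rng = random.Random(seed)

    def __len__(self):
        # one epoch covers the larger of the two domains
        return max(len(self.A), len(self.B))

    def __getitem__(self, index):
        a = self.A[index % len(self.A)]               # walk through domain A in order
        b = self.B[self.rng.randrange(len(self.B))]   # random, unrelated partner from B
        return a, b

# Hypothetical usage (folder names assumed):
# sampler = UnpairedSampler("datasets/horse2zebra/trainA", "datasets/horse2zebra/trainB")
# print(len(sampler), sampler[0])
```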

2. Training

2.1 Install dependencies

!pip install -r requirements.txt
Looking in indexes: http://repo.myhuaweicloud.com/repository/pypi/simple

Requirement already satisfied: torch>=1.4.0 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from -r requirements.txt (line 1)) (1.4.0)

Requirement already satisfied: torchvision>=0.5.0 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from -r requirements.txt (line 2)) (0.5.0)

Requirement already satisfied: dominate>=2.4.0 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from -r requirements.txt (line 3)) (2.6.0)

Requirement already satisfied: visdom>=0.1.8.8 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from -r requirements.txt (line 4)) (0.1.8.9)

Requirement already satisfied: pillow>=4.1.1 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (6.2.0)

Requirement already satisfied: numpy in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (1.20.2)

Requirement already satisfied: six in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (1.15.0)

Requirement already satisfied: jsonpatch in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (1.32)

Requirement already satisfied: torchfile in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (0.1.0)

Requirement already satisfied: tornado in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (6.1)

Requirement already satisfied: requests in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (2.25.1)

Requirement already satisfied: pyzmq in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (22.0.3)

Requirement already satisfied: websocket-client in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (1.1.0)

Requirement already satisfied: scipy in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (1.3.2)

Requirement already satisfied: jsonpointer>=1.9 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from jsonpatch->visdom>=0.1.8.8->-r requirements.txt (line 4)) (2.1)

Requirement already satisfied: chardet<5,>=3.0.2 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (4.0.0)

Requirement already satisfied: certifi>=2017.4.17 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (2020.12.5)

Requirement already satisfied: idna<3,>=2.5 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (2.10)

Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (1.26.4)

WARNING: You are using pip version 20.3.3; however, version 21.1.3 is available.

You should consider upgrading via the '/home/ma-user/anaconda3/envs/PyTorch-1.4/bin/python -m pip install --upgrade pip' command.
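To confirm from inside the notebook kernel that the requirements are importable, a small stdlib-only check can be used. The package list below is taken from the four top-level requirements visible in the pip output above (torch, torchvision, dominate, visdom); check_packages is an illustrative helper, not part of the project.

```python
import importlib.util

# Top-level packages from requirements.txt, as listed in the pip log above.
REQUIRED = ["torch", "torchvision", "dominate", "visdom"]

def check_packages(names=REQUIRED):
    """Map each package name to True if it can be found by the current interpreter."""
    return {name: importlib.util.find_spec(name) is not None for name in names}

missing = [name for name, ok in check_packages().items() if not ok]
print("missing packages:", missing or "none")
```

find_spec only locates the package without importing it, so the check is fast even for torch.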

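2.2 Start training

Training parameters can be viewed and modified in CycleGAN/options/train_options.py.
If you use a different dataset, change the data path passed to --dataroot.
The model (experiment) is named horse2zebra via --name; its checkpoints are saved under ./checkpoints/horse2zebra.
Note that n_epochs and n_epochs_decay are both 1 here (see the options printed below), so training runs for only 2 epochs of roughly 4 minutes each on the V100. This is a quick demonstration rather than a full training run; the CycleGAN paper trains for 200 epochs.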
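One of the training parameters, pool_size (50 in the options printed below), sets the size of an image history buffer. Following the CycleGAN paper, the discriminators are updated with a mix of newly generated images and older ones to reduce oscillation. Below is a minimal pure-Python sketch of such a buffer, modeled on our understanding of the repository's ImagePool; it stores arbitrary objects instead of tensors and handles one image per call, so treat it as an illustration rather than the project's code.

```python
import random

class ImagePool:
    """History buffer of generated images used when updating a discriminator (sketch).

    Until the buffer holds pool_size images, each new image is stored and returned.
    Afterwards, with probability 0.5 the new image replaces a random stored one and
    the older stored image is returned; otherwise the new image is returned unchanged.
    """

    def __init__(self, pool_size=50, seed=None):
        self.pool_size = pool_size
        self.images = []
        self.rng = random.Random(seed)

    def query(self, image):
        if self.pool_size == 0:               # buffer disabled: pass images through
            return image
        if len(self.images) < self.pool_size:
            self.images.append(image)          # still filling the buffer
            return image
        if self.rng.random() < 0.5:
            idx = self.rng.randrange(self.pool_size)
            old = self.images[idx]
            self.images[idx] = image           # swap the new image into the buffer
            return old                         # discriminator sees an older fake
        return image

pool = ImagePool(pool_size=50, seed=0)
shown = [pool.query(step) for step in range(200)]   # integers stand in for fake images
print(len(pool.images), sum(s != i for i, s in enumerate(shown)), "older images returned")
```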

!python train.py --dataroot ./datasets/horse2zebra --name horse2zebra --model cycle_gan
----------------- Options ---------------

               batch_size: 1                             

                    beta1: 0.5                           

          checkpoints_dir: ./checkpoints                 

           continue_train: False                         

                crop_size: 256                           

                 dataroot: ./datasets/horse2zebra        	[default: None]

             dataset_mode: unaligned                     

                direction: AtoB                          

              display_env: main                          

             display_freq: 400                           

               display_id: 1                             

            display_ncols: 4                             

             display_port: 8097                          

           display_server: http://localhost              

          display_winsize: 256                           

                    epoch: latest                        

              epoch_count: 1                             

                 gan_mode: lsgan                         

                  gpu_ids: 0                             

                init_gain: 0.02                          

                init_type: normal                        

                 input_nc: 3                             

                  isTrain: True                          	[default: None]

                 lambda_A: 10.0                          

                 lambda_B: 10.0                          

          lambda_identity: 0.5                           

                load_iter: 0                             	[default: 0]

                load_size: 286                           

                       lr: 0.0002                        

           lr_decay_iters: 50                            

                lr_policy: linear                        

         max_dataset_size: inf                           

                    model: cycle_gan                     

                 n_epochs: 1                             

           n_epochs_decay: 1                             

               n_layers_D: 3                             

                     name: horse2zebra                   	[default: experiment_name]

                      ndf: 64                            

                     netD: basic                         

                     netG: resnet_9blocks                

                      ngf: 64                            

               no_dropout: True                          

                  no_flip: False                         

                  no_html: False                         

                     norm: instance                      

              num_threads: 4                             

                output_nc: 3                             

                    phase: train                         

                pool_size: 50                            

               preprocess: resize_and_crop               

               print_freq: 100                           

             save_by_iter: False                         

          save_epoch_freq: 1                             

         save_latest_freq: 5000                          

           serial_batches: False                         

                   suffix:                               

         update_html_freq: 1000                          

                  verbose: False                         

----------------- End -------------------

dataset [UnalignedDataset] was created

The number of training images = 1334

initialize network with normal

initialize network with normal

initialize network with normal

initialize network with normal

model [CycleGANModel] was created

---------- Networks initialized -------------

[Network G_A] Total number of parameters : 11.378 M

[Network G_B] Total number of parameters : 11.378 M

[Network D_A] Total number of parameters : 2.765 M

[Network D_B] Total number of parameters : 2.765 M

-----------------------------------------------

/home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:122: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

  "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

learning rate 0.0002000 -> 0.0001000

(epoch: 1, iters: 100, time: 0.205, data: 0.142) 

(epoch: 1, iters: 200, time: 0.205, data: 0.001) 

(epoch: 1, iters: 300, time: 0.204, data: 0.002) 

(epoch: 1, iters: 400, time: 0.204, data: 0.001) 

(epoch: 1, iters: 500, time: 0.204, data: 0.002) 

(epoch: 1, iters: 600, time: 0.204, data: 0.001) 

(epoch: 1, iters: 700, time: 0.204, data: 0.002) 

(epoch: 1, iters: 800, time: 0.204, data: 0.002) 

(epoch: 1, iters: 900, time: 0.204, data: 0.001) 

(epoch: 1, iters: 1000, time: 0.204, data: 0.001) 

(epoch: 1, iters: 1100, time: 0.204, data: 0.001) 

(epoch: 1, iters: 1200, time: 0.204, data: 0.002) 

(epoch: 1, iters: 1300, time: 0.202, data: 0.001) 

saving the model at the end of epoch 1, iters 1334

End of epoch 1 / 2 	 Time Taken: 255 sec

learning rate 0.0001000 -> 0.0000000

(epoch: 2, iters: 66, time: 0.204, data: 0.001) 

(epoch: 2, iters: 166, time: 0.204, data: 0.002) 

(epoch: 2, iters: 266, time: 0.203, data: 0.002) 

(epoch: 2, iters: 366, time: 0.204, data: 0.002) 

(epoch: 2, iters: 466, time: 0.204, data: 0.001) 

(epoch: 2, iters: 566, time: 0.204, data: 0.002) 

(epoch: 2, iters: 666, time: 0.204, data: 0.002) 

(epoch: 2, iters: 766, time: 0.204, data: 0.002) 

(epoch: 2, iters: 866, time: 0.204, data: 0.002) 

(epoch: 2, iters: 966, time: 0.204, data: 0.002) 

(epoch: 2, iters: 1066, time: 0.202, data: 0.001) 

(epoch: 2, iters: 1166, time: 0.204, data: 0.002) 

(epoch: 2, iters: 1266, time: 0.204, data: 0.002) 

saving the model at the end of epoch 2, iters 2668

End of epoch 2 / 2 	 Time Taken: 254 sec
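The two "learning rate" lines in the log above come from the linear policy (lr_policy: linear). The function below sketches that schedule; the formula is our reading of the upstream pytorch-CycleGAN-and-pix2pix code and should be treated as an assumption, but with this notebook's settings (lr 0.0002, epoch_count 1, n_epochs 1, n_epochs_decay 1) it reproduces the logged values 0.0002 → 0.0001 → 0. Because the log prints each update before that epoch's iterations, epoch 1 appears to train at 0.0001 and epoch 2 at 0, so this 2-epoch run should be seen as a demonstration only.

```python
def linear_lr(base_lr, epoch, epoch_count=1, n_epochs=1, n_epochs_decay=1):
    """Learning rate after `epoch` scheduler steps under a linear-decay policy (sketch).

    The rate stays at base_lr for n_epochs epochs, then decays linearly towards 0
    over the following n_epochs_decay epochs.
    """
    factor = 1.0 - max(0, epoch + epoch_count - n_epochs) / float(n_epochs_decay + 1)
    return base_lr * factor

# Settings from this notebook's options: lr=0.0002, epoch_count=1, n_epochs=1, n_epochs_decay=1
for step in range(3):
    print(f"after {step} step(s): lr = {linear_lr(0.0002, step):.7f}")
# after 0 step(s): lr = 0.0002000
# after 1 step(s): lr = 0.0001000
# after 2 step(s): lr = 0.0000000
```

With larger n_epochs and n_epochs_decay, the same function keeps the rate constant for longer before decaying it.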

3. Testing

Check whether the horse2zebra model trained above has been created; if it has, it appears as a folder under checkpoints/.

!ls checkpoints/
horse2zebra  horse2zebra_pretrained
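The same check can be done from Python. This sketch assumes the checkpoint naming scheme <epoch>_net_<network>.pth (for example latest_net_G_A.pth), which is consistent with the latest_net_G.pth file loaded in the test log below; list_checkpoints is an illustrative helper.

```python
from pathlib import Path

def list_checkpoints(checkpoints_dir="./checkpoints", name="horse2zebra"):
    """Group an experiment's .pth files by epoch label, e.g. {'latest': ['D_A', ...]}."""
    exp_dir = Path(checkpoints_dir, name)
    if not exp_dir.is_dir():
        return {}
    found = {}
    for path in sorted(exp_dir.glob("*_net_*.pth")):
        # "latest_net_G_A.pth" -> epoch label "latest", network name "G_A"
        epoch, _, net = path.stem.partition("_net_")
        found.setdefault(epoch, []).append(net)
    return found

print(list_checkpoints())
```

The test command below uses the pretrained model. To test the model trained above instead, the model_suffix option (shown, empty, in the test options below) should select which generator file is loaded, e.g. `--name horse2zebra --model test --model_suffix _A --no_dropout`; this variant was not run here, so verify it with `python test.py --help`.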

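Run the test. Note that this command uses horse2zebra_pretrained, a pretrained model already present in checkpoints/ (as the ls output shows), not the horse2zebra model trained above. With --model test, a single generator is loaded and applied in one direction (A→B, horse to zebra) to the images in testA.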

!python test.py --dataroot datasets/horse2zebra/testA --name horse2zebra_pretrained --model test --no_dropout
----------------- Options ---------------

             aspect_ratio: 1.0                           

               batch_size: 1                             

          checkpoints_dir: ./checkpoints                 

                crop_size: 256                           

                 dataroot: datasets/horse2zebra/testA    	[default: None]

             dataset_mode: single                        

                direction: AtoB                          

          display_winsize: 256                           

                    epoch: latest                        

                     eval: False                         

                  gpu_ids: 0                             

                init_gain: 0.02                          

                init_type: normal                        

                 input_nc: 3                             

                  isTrain: False                         	[default: None]

                load_iter: 0                             	[default: 0]

                load_size: 256                           

         max_dataset_size: inf                           

                    model: test                          

             model_suffix:                               

               n_layers_D: 3                             

                     name: horse2zebra_pretrained        	[default: experiment_name]

                      ndf: 64                            

                     netD: basic                         

                     netG: resnet_9blocks                

                      ngf: 64                            

               no_dropout: True                          	[default: False]

                  no_flip: False                         

                     norm: instance                      

                 num_test: 50                            

              num_threads: 4                             

                output_nc: 3                             

                    phase: test                          

               preprocess: resize_and_crop               

              results_dir: ./results/                    

           serial_batches: False                         

                   suffix:                               

                  verbose: False                         

----------------- End -------------------

dataset [SingleDataset] was created

initialize network with normal

model [TestModel] was created

loading the model from ./checkpoints/horse2zebra_pretrained/latest_net_G.pth

---------- Networks initialized -------------

[Network G] Total number of parameters : 11.378 M

-----------------------------------------------

creating web directory ./results/horse2zebra_pretrained/test_latest

processing (0000)-th image... ['datasets/horse2zebra/testA/n02381460_1000.jpg']

processing (0005)-th image... ['datasets/horse2zebra/testA/n02381460_1110.jpg']

processing (0010)-th image... ['datasets/horse2zebra/testA/n02381460_1260.jpg']

processing (0015)-th image... ['datasets/horse2zebra/testA/n02381460_1420.jpg']

processing (0020)-th image... ['datasets/horse2zebra/testA/n02381460_1690.jpg']

processing (0025)-th image... ['datasets/horse2zebra/testA/n02381460_1830.jpg']

processing (0030)-th image... ['datasets/horse2zebra/testA/n02381460_2050.jpg']

processing (0035)-th image... ['datasets/horse2zebra/testA/n02381460_2460.jpg']

processing (0040)-th image... ['datasets/horse2zebra/testA/n02381460_2870.jpg']

processing (0045)-th image... ['datasets/horse2zebra/testA/n02381460_3040.jpg']

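Display the test results

The generated results are under ./results/horse2zebra_pretrained/. In test_latest/images/, each input appears as <name>_real.png (the input) and <name>_fake.png (the translated output).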

import matplotlib.pyplot as plt

# Translated output: the horse image rendered as a zebra by the pretrained generator
img = plt.imread('./results/horse2zebra_pretrained/test_latest/images/n02381460_1010_fake.png')
plt.imshow(img)
[Figure: generated zebra image n02381460_1010_fake.png]

import matplotlib.pyplot as plt

# Original input horse image, for comparison
img = plt.imread('./results/horse2zebra_pretrained/test_latest/images/n02381460_1010_real.png')
plt.imshow(img)
[Figure: original horse image n02381460_1010_real.png]
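To browse more than one result, the real and fake images can be paired by file name and plotted side by side. This sketch assumes the <name>_real.png / <name>_fake.png naming used in the two cells above; result_pairs and show_pairs are illustrative helpers, and matplotlib is imported only when there is something to plot.

```python
from pathlib import Path

def result_pairs(images_dir="./results/horse2zebra_pretrained/test_latest/images"):
    """Return sorted (name, real_path, fake_path) tuples for inputs that have both files."""
    images_dir = Path(images_dir)
    if not images_dir.is_dir():
        return []
    pairs = []
    for real in sorted(images_dir.glob("*_real.png")):
        name = real.name[: -len("_real.png")]
        fake = images_dir / f"{name}_fake.png"
        if fake.exists():
            pairs.append((name, real, fake))
    return pairs

def show_pairs(pairs, max_rows=3):
    """Plot up to max_rows (real, fake) pairs, one pair per row."""
    if not pairs:
        print("no result pairs found")
        return
    import matplotlib.pyplot as plt  # lazy import: result_pairs works without matplotlib
    rows = min(max_rows, len(pairs))
    fig, axes = plt.subplots(rows, 2, figsize=(6, 3 * rows), squeeze=False)
    for (name, real, fake), (ax_real, ax_fake) in zip(pairs, axes):
        ax_real.imshow(plt.imread(str(real)))
        ax_real.set_title(f"{name} (real)")
        ax_fake.imshow(plt.imread(str(fake)))
        ax_fake.set_title(f"{name} (fake)")
        ax_real.axis("off")
        ax_fake.axis("off")
    fig.tight_layout()
    plt.show()

pairs = result_pairs()
print(f"{len(pairs)} real/fake pairs found")
show_pairs(pairs)
```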


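[Copyright notice] This article is original content by a Huawei Cloud community user. When reposting, you must credit the source (Huawei Cloud community), the article link, the author, and other basic information; otherwise the author and the community reserve the right to pursue liability. If you find suspected plagiarism in this community, please report it by email with supporting evidence; once verified, the community will promptly remove the infringing content. Report email: cloudbbs@huaweicloud.com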