pix2pix Image Style Transfer

Posted by HWCloudAI on 2022/12/01 10:54:10

pix2pix

Paper link: https://arxiv.org/abs/1611.07004

Many image-processing problems amount to translating an input image into a corresponding output image, such as conversions among grayscale images, gradient maps, and color images. Traditionally, each such problem is handled by a purpose-built algorithm: when a CNN is used for image translation, a task-specific loss function must be designed for it to optimize, and the common approach of training the CNN to minimize the Euclidean distance between input and output tends to produce blurry results. At heart, all of these methods are pixel-to-pixel mappings. Building on GANs, the paper therefore proposes a general-purpose method, pix2pix, for this class of problems. pix2pix performs paired image translation (Labels to Street Scene, Aerial to Map, Day to Night, and so on) and produces comparatively sharp results.
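Concretely, the pix2pix generator is trained on a combined objective: a conditional-GAN term that pushes outputs to look realistic, plus an L1 term (weighted by a lambda of 100 in this run's options) that keeps them close to the ground truth. A minimal numpy sketch of that objective — an illustration, not the repository's implementation:

```python
import numpy as np

def pix2pix_generator_loss(d_logits_on_fake, fake_b, real_b, lambda_l1=100.0):
    """Generator objective from the pix2pix paper:
    cGAN term (fool the discriminator) + lambda_l1 * L1(fake, real)."""
    # Vanilla GAN loss with target label "real": BCE-with-logits at target 1,
    # i.e. mean(log(1 + exp(-logit)))
    gan_term = np.mean(np.log1p(np.exp(-d_logits_on_fake)))
    # The L1 term pulls outputs toward the ground truth while blurring
    # less than an L2 penalty would
    l1_term = np.mean(np.abs(fake_b - real_b))
    return gan_term + lambda_l1 * l1_term
```

With zero logits (a maximally unsure discriminator) the GAN term is ln 2, so the L1 term dominates at lambda = 100 — which is why pix2pix outputs stay faithful to the paired target.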

Notes:

Framework: PyTorch 1.4.0

Hardware: 8 vCPU + 64 GiB + 1 x Tesla V100-PCIE-32GB

How to run the code: click the triangular Run button in the menu bar at the top of this page, or press Ctrl+Enter, to run the code in each cell

Detailed usage of JupyterLab: see the "ModelArts JupyterLab User Guide"

Troubleshooting: see the "ModelArts JupyterLab FAQ"

1. Download the code and dataset

Run the code below to download and unzip the code and data.

We use the facades dataset; the data are located in pix2pix/datasets/facades/.

import os
# Download the code and data
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/pix2pix.zip
# Unzip
os.system('unzip pix2pix.zip -d ./')
os.chdir('./pix2pix')
--2021-07-14 14:13:43--  https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/pix2pix.zip

Resolving proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)... 192.168.0.172

Connecting to proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)|192.168.0.172|:8083... connected.

Proxy request sent, awaiting response... 200 OK

Length: 908043347 (866M) [application/zip]

Saving to: ‘pix2pix.zip’



pix2pix.zip         100%[===================>] 865.98M   216MB/s    in 4.1s    



2021-07-14 14:13:49 (209 MB/s) - ‘pix2pix.zip’ saved [908043347/908043347]
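Before training, it can help to inspect one sample. The facades files follow the repository's aligned format: each file stores both domains side by side in a single image, with A as the left half and B as the right half. A small sketch, assuming pillow (installed alongside torchvision) is available:

```python
from PIL import Image

def split_aligned(img):
    """Split one aligned sample into its (A, B) halves.

    The repository's AlignedDataset stores both domains side by side
    in a single file: A is the left half, B is the right half.
    """
    if isinstance(img, str):          # accept a path or an already-open image
        img = Image.open(img)
    w, h = img.size
    a = img.crop((0, 0, w // 2, h))   # left half  -> domain A
    b = img.crop((w // 2, 0, w, h))   # right half -> domain B
    return a, b
```

For example, `split_aligned('./datasets/facades/train/1.jpg')` would return the two halves of that training pair as separate images.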

2. Training

2.1 Install dependencies

!pip install -r requirements.txt
Looking in indexes: http://repo.myhuaweicloud.com/repository/pypi/simple

Requirement already satisfied: torch>=1.4.0 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from -r requirements.txt (line 1)) (1.4.0)

Requirement already satisfied: torchvision>=0.5.0 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from -r requirements.txt (line 2)) (0.5.0)

Collecting dominate>=2.4.0

  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/ef/a8/4354f8122c39e35516a2708746d89db5e339c867abbd8e0179bccee4b7f9/dominate-2.6.0-py2.py3-none-any.whl (29 kB)

Requirement already satisfied: pillow>=4.1.1 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (6.2.0)

Requirement already satisfied: numpy in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (1.20.2)

Requirement already satisfied: six in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (1.15.0)

Collecting visdom>=0.1.8.8

  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/c9/75/e078f5a2e1df7e0d3044749089fc2823e62d029cc027ed8ae5d71fafcbdc/visdom-0.1.8.9.tar.gz (676 kB)

     |████████████████████████████████| 676 kB 4.1 MB/s eta 0:00:01

Requirement already satisfied: scipy in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (1.3.2)

Requirement already satisfied: requests in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (2.25.1)

Requirement already satisfied: tornado in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (6.1)

Requirement already satisfied: pyzmq in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (22.0.3)

Collecting jsonpatch

  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/a3/55/f7c93bae36d869292aedfbcbae8b091386194874f16390d680136edd2b28/jsonpatch-1.32-py2.py3-none-any.whl (12 kB)

Collecting jsonpointer>=1.9

  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/23/52/05f67532aa922e494c351344e0d9624a01f74f5dd8402fe0d1b563a6e6fc/jsonpointer-2.1-py2.py3-none-any.whl (7.4 kB)

Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (1.26.4)

Requirement already satisfied: certifi>=2017.4.17 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (2020.12.5)

Requirement already satisfied: idna<3,>=2.5 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (2.10)

Requirement already satisfied: chardet<5,>=3.0.2 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (4.0.0)

Collecting torchfile

  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/91/af/5b305f86f2d218091af657ddb53f984ecbd9518ca9fe8ef4103a007252c9/torchfile-0.1.0.tar.gz (5.2 kB)

Collecting websocket-client

  Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/ca/5f/3c211d168b2e9f9342cfb53bcfc26aab0eac63b998015e7af7bcae66119d/websocket_client-1.1.0-py2.py3-none-any.whl (68 kB)

     |████████████████████████████████| 68 kB 97.9 MB/s  eta 0:00:01

Building wheels for collected packages: visdom, torchfile

  Building wheel for visdom (setup.py) ... done

  Created wheel for visdom: filename=visdom-0.1.8.9-py3-none-any.whl size=655249 sha256=af2915fa47add62b99323013ddead7e208dcd7e72337fde88c4c1d3ef4db4c2e

  Stored in directory: /home/ma-user/.cache/pip/wheels/c2/4f/64/6370a1da43381982410b766af6c395acfe497647ce282a6ce3

  Building wheel for torchfile (setup.py) ... done

  Created wheel for torchfile: filename=torchfile-0.1.0-py3-none-any.whl size=5711 sha256=2fb039c294bed9028ef6e933b120b4af720ef9cf992726be8cc4958a7b5d07f0

  Stored in directory: /home/ma-user/.cache/pip/wheels/c2/b4/9a/0740ff703d3f6b175a0b7c34cbd592ee21fae893e06f37d70c

Successfully built visdom torchfile

Installing collected packages: jsonpointer, websocket-client, torchfile, jsonpatch, visdom, dominate

Successfully installed dominate-2.6.0 jsonpatch-1.32 jsonpointer-2.1 torchfile-0.1.0 visdom-0.1.8.9 websocket-client-1.1.0

WARNING: You are using pip version 20.3.3; however, version 21.1.3 is available.

You should consider upgrading via the '/home/ma-user/anaconda3/envs/PyTorch-1.4/bin/python -m pip install --upgrade pip' command.

2.2 Start training

Training parameters can be viewed and modified in pix2pix/options/train_options.py
If you use a different dataset, change the data path accordingly
The model is named facades_pix2pix

!python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
----------------- Options ---------------

               batch_size: 1                             

                    beta1: 0.5                           

          checkpoints_dir: ./checkpoints                 

           continue_train: False                         

                crop_size: 256                           

                 dataroot: ./datasets/facades            	[default: None]

             dataset_mode: aligned                       

                direction: BtoA                          	[default: AtoB]

              display_env: main                          

             display_freq: 400                           

               display_id: 1                             

            display_ncols: 4                             

             display_port: 8097                          

           display_server: http://localhost              

          display_winsize: 256                           

                    epoch: latest                        

              epoch_count: 1                             

                 gan_mode: vanilla                       

                  gpu_ids: 0                             

                init_gain: 0.02                          

                init_type: normal                        

                 input_nc: 3                             

                  isTrain: True                          	[default: None]

                lambda_L1: 100.0                         

                load_iter: 0                             	[default: 0]

                load_size: 286                           

                       lr: 0.0002                        

           lr_decay_iters: 50                            

                lr_policy: linear                        

         max_dataset_size: inf                           

                    model: pix2pix                       

                 n_epochs: 5                             

           n_epochs_decay: 5                             

               n_layers_D: 3                             

                     name: facades_pix2pix               	[default: experiment_name]

                      ndf: 64                            

                     netD: basic                         

                     netG: unet_256                      

                      ngf: 64                            

               no_dropout: False                         

                  no_flip: False                         

                  no_html: False                         

                     norm: batch                         

              num_threads: 4                             

                output_nc: 3                             

                    phase: train                         

                pool_size: 0                             

               preprocess: resize_and_crop               

               print_freq: 100                           

             save_by_iter: False                         

          save_epoch_freq: 5                             

         save_latest_freq: 5000                          

           serial_batches: False                         

                   suffix:                               

         update_html_freq: 1000                          

                  verbose: False                         

----------------- End -------------------

dataset [AlignedDataset] was created

The number of training images = 405

initialize network with normal

initialize network with normal

model [Pix2PixModel] was created

---------- Networks initialized -------------

[Network G] Total number of parameters : 54.414 M

[Network D] Total number of parameters : 2.769 M

-----------------------------------------------

/home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:122: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`.  Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate

  "https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)

learning rate 0.0002000 -> 0.0002000

(epoch: 1, iters: 100, time: 0.041, data: 0.195) 

(epoch: 1, iters: 200, time: 0.041, data: 0.001) 

(epoch: 1, iters: 300, time: 0.039, data: 0.001) 

(epoch: 1, iters: 400, time: 0.041, data: 0.001) 

End of epoch 1 / 10 	 Time Taken: 13 sec

learning rate 0.0002000 -> 0.0002000

(epoch: 2, iters: 95, time: 0.042, data: 0.001) 

(epoch: 2, iters: 195, time: 0.042, data: 0.001) 

(epoch: 2, iters: 295, time: 0.042, data: 0.001) 

(epoch: 2, iters: 395, time: 0.042, data: 0.001) 

End of epoch 2 / 10 	 Time Taken: 12 sec

learning rate 0.0002000 -> 0.0002000

(epoch: 3, iters: 90, time: 0.039, data: 0.001) 

(epoch: 3, iters: 190, time: 0.041, data: 0.001) 

(epoch: 3, iters: 290, time: 0.042, data: 0.001) 

(epoch: 3, iters: 390, time: 0.042, data: 0.001) 

End of epoch 3 / 10 	 Time Taken: 12 sec

learning rate 0.0002000 -> 0.0002000

(epoch: 4, iters: 85, time: 0.041, data: 0.001) 

(epoch: 4, iters: 185, time: 0.042, data: 0.001) 

(epoch: 4, iters: 285, time: 0.042, data: 0.001) 

(epoch: 4, iters: 385, time: 0.041, data: 0.001) 

End of epoch 4 / 10 	 Time Taken: 12 sec

learning rate 0.0002000 -> 0.0001667

(epoch: 5, iters: 80, time: 0.041, data: 0.001) 

(epoch: 5, iters: 180, time: 0.042, data: 0.001) 

(epoch: 5, iters: 280, time: 0.042, data: 0.001) 

(epoch: 5, iters: 380, time: 0.041, data: 0.001) 

saving the model at the end of epoch 5, iters 2025

End of epoch 5 / 10 	 Time Taken: 14 sec

learning rate 0.0001667 -> 0.0001333

(epoch: 6, iters: 75, time: 0.041, data: 0.001) 

(epoch: 6, iters: 175, time: 0.042, data: 0.001) 

(epoch: 6, iters: 275, time: 0.041, data: 0.001) 

(epoch: 6, iters: 375, time: 0.041, data: 0.001) 

End of epoch 6 / 10 	 Time Taken: 12 sec

learning rate 0.0001333 -> 0.0001000

(epoch: 7, iters: 70, time: 0.041, data: 0.001) 

(epoch: 7, iters: 170, time: 0.042, data: 0.001) 

(epoch: 7, iters: 270, time: 0.041, data: 0.001) 

(epoch: 7, iters: 370, time: 0.041, data: 0.001) 

End of epoch 7 / 10 	 Time Taken: 12 sec

learning rate 0.0001000 -> 0.0000667

(epoch: 8, iters: 65, time: 0.041, data: 0.001) 

(epoch: 8, iters: 165, time: 0.041, data: 0.001) 

(epoch: 8, iters: 265, time: 0.041, data: 0.001) 

(epoch: 8, iters: 365, time: 0.041, data: 0.001) 

End of epoch 8 / 10 	 Time Taken: 12 sec

learning rate 0.0000667 -> 0.0000333

(epoch: 9, iters: 60, time: 0.041, data: 0.001) 

(epoch: 9, iters: 160, time: 0.041, data: 0.001) 

(epoch: 9, iters: 260, time: 0.041, data: 0.001) 

(epoch: 9, iters: 360, time: 0.039, data: 0.001) 

End of epoch 9 / 10 	 Time Taken: 12 sec

learning rate 0.0000333 -> 0.0000000

(epoch: 10, iters: 55, time: 0.041, data: 0.002) 

(epoch: 10, iters: 155, time: 0.042, data: 0.001) 

(epoch: 10, iters: 255, time: 0.041, data: 0.001) 

(epoch: 10, iters: 355, time: 0.040, data: 0.001) 

saving the model at the end of epoch 10, iters 4050

End of epoch 10 / 10 	 Time Taken: 14 sec
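The learning-rate lines in the log above follow the repository's linear policy: the rate stays at its initial value for n_epochs epochs, then decays linearly to zero over the next n_epochs_decay epochs. A sketch of that rule, matching the 5 + 5 epochs configured in this run:

```python
def linear_lr(epoch, base_lr=2e-4, n_epochs=5, n_epochs_decay=5, epoch_count=1):
    """Learning rate for a given (1-based) training epoch under the
    'linear' lr_policy: constant for n_epochs epochs, then linear
    decay to zero over n_epochs_decay epochs."""
    factor = 1.0 - max(0, epoch + epoch_count - n_epochs) / (n_epochs_decay + 1)
    return base_lr * max(0.0, factor)
```

This reproduces the log: epochs 1-4 stay at 0.0002, epoch 5 drops to 0.0002 x 5/6 ≈ 0.0001667, and epoch 10 reaches 0.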

3. Testing

Check whether the model facades_pix2pix trained above has been saved; if so, it will appear under the checkpoints directory

!ls checkpoints/
facades_label2photo_pretrained	facades_pix2pix

Run the test with the trained model facades_pix2pix

!python test.py --dataroot ./datasets/facades --direction BtoA --model pix2pix --name facades_pix2pix
----------------- Options ---------------

             aspect_ratio: 1.0                           

               batch_size: 1                             

          checkpoints_dir: ./checkpoints                 

                crop_size: 256                           

                 dataroot: ./datasets/facades            	[default: None]

             dataset_mode: aligned                       

                direction: BtoA                          	[default: AtoB]

          display_winsize: 256                           

                    epoch: latest                        

                     eval: False                         

                  gpu_ids: 0                             

                init_gain: 0.02                          

                init_type: normal                        

                 input_nc: 3                             

                  isTrain: False                         	[default: None]

                load_iter: 0                             	[default: 0]

                load_size: 256                           

         max_dataset_size: inf                           

                    model: pix2pix                       	[default: test]

               n_layers_D: 3                             

                     name: facades_pix2pix               	[default: experiment_name]

                      ndf: 64                            

                     netD: basic                         

                     netG: unet_256                      

                      ngf: 64                            

               no_dropout: False                         

                  no_flip: False                         

                     norm: batch                         

                 num_test: 50                            

              num_threads: 4                             

                output_nc: 3                             

                    phase: test                          

               preprocess: resize_and_crop               

              results_dir: ./results/                    

           serial_batches: False                         

                   suffix:                               

                  verbose: False                         

----------------- End -------------------

dataset [AlignedDataset] was created

initialize network with normal

model [Pix2PixModel] was created

loading the model from ./checkpoints/facades_pix2pix/latest_net_G.pth

---------- Networks initialized -------------

[Network G] Total number of parameters : 54.414 M

-----------------------------------------------

creating web directory ./results/facades_pix2pix/test_latest

processing (0000)-th image... ['./datasets/facades/test/1.jpg']

processing (0005)-th image... ['./datasets/facades/test/103.jpg']

processing (0010)-th image... ['./datasets/facades/test/12.jpg']

processing (0015)-th image... ['./datasets/facades/test/17.jpg']

processing (0020)-th image... ['./datasets/facades/test/21.jpg']

processing (0025)-th image... ['./datasets/facades/test/26.jpg']

processing (0030)-th image... ['./datasets/facades/test/30.jpg']

processing (0035)-th image... ['./datasets/facades/test/35.jpg']

processing (0040)-th image... ['./datasets/facades/test/4.jpg']

processing (0045)-th image... ['./datasets/facades/test/44.jpg']

Display the test results

The generated results can be found under ./results/facades_pix2pix/
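Each test sample is saved as three images with a fixed naming scheme: `<index>_real_A` (input), `<index>_fake_B` (generated output), and `<index>_real_B` (ground truth). A small helper to build all three paths for one index, using the results directory created in the test log above:

```python
import os

def result_paths(index, root='./results/facades_pix2pix/test_latest/images'):
    """Paths of the input, generated, and ground-truth images for one
    test sample, following the test script's naming convention."""
    return [os.path.join(root, f'{index}_{tag}.png')
            for tag in ('real_A', 'fake_B', 'real_B')]
```

The cells below load the three images for sample 100 one at a time; the helper simply gathers those same filenames in one place.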

import matplotlib.pyplot as plt

img = plt.imread('./results/facades_pix2pix/test_latest/images/100_fake_B.png')
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7ff680a8cd50>

(displayed image: 100_fake_B.png)

img = plt.imread('./results/facades_pix2pix/test_latest/images/100_real_A.png')
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7ff680524090>

(displayed image: 100_real_A.png)

img = plt.imread('./results/facades_pix2pix/test_latest/images/100_real_B.png')
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7ff6613ceb90>

(displayed image: 100_real_B.png)

[Copyright notice] This article is original content by a Huawei Cloud community user. Reprints must credit the source (Huawei Cloud community), the article link, and the author; otherwise the author and the community reserve the right to pursue liability. If you find suspected plagiarism in this community, please report it by email with supporting evidence; once verified, the community will immediately remove the infringing content. Report email: cloudbbs@huaweicloud.com