pix2pix Image-to-Image Translation (Style Transfer)
pix2pix
Paper: https://arxiv.org/abs/1611.07004
Many problems in image processing amount to converting an input image into a corresponding output image, for example conversions between grayscale images, gradient maps, and color images. Traditionally, each such problem is solved with a task-specific algorithm: when a CNN is used for image translation, a dedicated loss function has to be designed for each task, and the common approach of training the CNN to minimize the Euclidean distance between the predicted and target outputs typically produces blurry results. In essence, all of these methods are pixel-to-pixel mappings. Building on GANs, the paper therefore proposes a general-purpose method, pix2pix, for this whole class of problems. With pix2pix, paired image translation tasks (Labels to Street Scene, Aerial to Map, Day to Night, etc.) yield noticeably sharper results.
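The key idea is to pair the conditional GAN loss with an L1 reconstruction term, so the generator is pushed toward realistic images while staying close to the ground truth. Below is a minimal sketch of that combined generator objective; the tensor names and the standalone function are hypothetical illustrations, and the repository's own implementation may differ in detail.
import torch
import torch.nn as nn
# Sketch of the pix2pix generator objective:
#   L_G = L_cGAN(G, D) + lambda_L1 * ||y - G(x)||_1
# The GAN term encourages realistic outputs; the L1 term prevents the
# blurriness that a pure pixel-distance loss tends to produce.
criterion_gan = nn.BCEWithLogitsLoss()  # "vanilla" GAN loss on raw logits
criterion_l1 = nn.L1Loss()
lambda_l1 = 100.0  # same value as the lambda_L1 training option below
def generator_loss(disc_logits_on_fake, fake_b, real_b):
    # The generator wants the discriminator to classify its output as real.
    loss_gan = criterion_gan(disc_logits_on_fake,
                             torch.ones_like(disc_logits_on_fake))
    loss_l1 = criterion_l1(fake_b, real_b)
    return loss_gan + lambda_l1 * loss_l1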
Notes:
Framework: PyTorch 1.4.0
Hardware: 8 vCPU + 64 GiB memory + 1 x Tesla V100-PCIE-32GB
How to run the code: click the triangular Run button in the menu bar at the top of this page, or press Ctrl+Enter, to execute the code in each cell
Detailed JupyterLab usage: see the ModelArts JupyterLab User Guide
If you run into problems: see the ModelArts JupyterLab Troubleshooting Guide
1. Download the Code and Dataset
Run the code below to download and unpack the code and data.
The facades dataset is used; the data is located in pix2pix/datasets/facades/.
import os
# Download the code and data
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/pix2pix.zip
# Unpack the archive
os.system('unzip pix2pix.zip -d ./')
os.chdir('./pix2pix')
--2021-07-14 14:13:43-- https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/algorithm/pix2pix.zip
Resolving proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)... 192.168.0.172
Connecting to proxy-notebook.modelarts-dev-proxy.com (proxy-notebook.modelarts-dev-proxy.com)|192.168.0.172|:8083... connected.
Proxy request sent, awaiting response... 200 OK
Length: 908043347 (866M) [application/zip]
Saving to: ‘pix2pix.zip’
pix2pix.zip 100%[===================>] 865.98M 216MB/s in 4.1s
2021-07-14 14:13:49 (209 MB/s) - ‘pix2pix.zip’ saved [908043347/908043347]
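Optionally, run a quick sanity check that the dataset unpacked as expected. This sketch assumes the usual train/val/test split of the facades dataset; adjust the paths if your copy differs.
import os
# Count the paired images in each split (hypothetical check,
# not part of the original notebook).
for split in ['train', 'val', 'test']:
    path = os.path.join('datasets', 'facades', split)
    if os.path.isdir(path):
        print(split, '->', len(os.listdir(path)), 'images')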
2. Training
2.1 Install Dependencies
!pip install -r requirements.txt
Looking in indexes: http://repo.myhuaweicloud.com/repository/pypi/simple
Requirement already satisfied: torch>=1.4.0 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from -r requirements.txt (line 1)) (1.4.0)
Requirement already satisfied: torchvision>=0.5.0 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from -r requirements.txt (line 2)) (0.5.0)
Collecting dominate>=2.4.0
Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/ef/a8/4354f8122c39e35516a2708746d89db5e339c867abbd8e0179bccee4b7f9/dominate-2.6.0-py2.py3-none-any.whl (29 kB)
Requirement already satisfied: pillow>=4.1.1 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (6.2.0)
Requirement already satisfied: numpy in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (1.20.2)
Requirement already satisfied: six in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from torchvision>=0.5.0->-r requirements.txt (line 2)) (1.15.0)
Collecting visdom>=0.1.8.8
Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/c9/75/e078f5a2e1df7e0d3044749089fc2823e62d029cc027ed8ae5d71fafcbdc/visdom-0.1.8.9.tar.gz (676 kB)
|████████████████████████████████| 676 kB 4.1 MB/s eta 0:00:01
Requirement already satisfied: scipy in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (1.3.2)
Requirement already satisfied: requests in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (2.25.1)
Requirement already satisfied: tornado in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (6.1)
Requirement already satisfied: pyzmq in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from visdom>=0.1.8.8->-r requirements.txt (line 4)) (22.0.3)
Collecting jsonpatch
Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/a3/55/f7c93bae36d869292aedfbcbae8b091386194874f16390d680136edd2b28/jsonpatch-1.32-py2.py3-none-any.whl (12 kB)
Collecting jsonpointer>=1.9
Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/23/52/05f67532aa922e494c351344e0d9624a01f74f5dd8402fe0d1b563a6e6fc/jsonpointer-2.1-py2.py3-none-any.whl (7.4 kB)
Requirement already satisfied: urllib3<1.27,>=1.21.1 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (1.26.4)
Requirement already satisfied: certifi>=2017.4.17 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (2020.12.5)
Requirement already satisfied: idna<3,>=2.5 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (2.10)
Requirement already satisfied: chardet<5,>=3.0.2 in /home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages (from requests->visdom>=0.1.8.8->-r requirements.txt (line 4)) (4.0.0)
Collecting torchfile
Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/91/af/5b305f86f2d218091af657ddb53f984ecbd9518ca9fe8ef4103a007252c9/torchfile-0.1.0.tar.gz (5.2 kB)
Collecting websocket-client
Downloading http://repo.myhuaweicloud.com/repository/pypi/packages/ca/5f/3c211d168b2e9f9342cfb53bcfc26aab0eac63b998015e7af7bcae66119d/websocket_client-1.1.0-py2.py3-none-any.whl (68 kB)
|████████████████████████████████| 68 kB 97.9 MB/s eta 0:00:01
Building wheels for collected packages: visdom, torchfile
Building wheel for visdom (setup.py) ... done
Created wheel for visdom: filename=visdom-0.1.8.9-py3-none-any.whl size=655249 sha256=af2915fa47add62b99323013ddead7e208dcd7e72337fde88c4c1d3ef4db4c2e
Stored in directory: /home/ma-user/.cache/pip/wheels/c2/4f/64/6370a1da43381982410b766af6c395acfe497647ce282a6ce3
Building wheel for torchfile (setup.py) ... done
Created wheel for torchfile: filename=torchfile-0.1.0-py3-none-any.whl size=5711 sha256=2fb039c294bed9028ef6e933b120b4af720ef9cf992726be8cc4958a7b5d07f0
Stored in directory: /home/ma-user/.cache/pip/wheels/c2/b4/9a/0740ff703d3f6b175a0b7c34cbd592ee21fae893e06f37d70c
Successfully built visdom torchfile
Installing collected packages: jsonpointer, websocket-client, torchfile, jsonpatch, visdom, dominate
Successfully installed dominate-2.6.0 jsonpatch-1.32 jsonpointer-2.1 torchfile-0.1.0 visdom-0.1.8.9 websocket-client-1.1.0
WARNING: You are using pip version 20.3.3; however, version 21.1.3 is available.
You should consider upgrading via the '/home/ma-user/anaconda3/envs/PyTorch-1.4/bin/python -m pip install --upgrade pip' command.
2.2 Start Training
Training parameters can be viewed and modified in pix2pix/options/train_options.py.
If you use a different dataset, change the data path accordingly.
The model is named facades_pix2pix.
!python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA
----------------- Options ---------------
batch_size: 1
beta1: 0.5
checkpoints_dir: ./checkpoints
continue_train: False
crop_size: 256
dataroot: ./datasets/facades [default: None]
dataset_mode: aligned
direction: BtoA [default: AtoB]
display_env: main
display_freq: 400
display_id: 1
display_ncols: 4
display_port: 8097
display_server: http://localhost
display_winsize: 256
epoch: latest
epoch_count: 1
gan_mode: vanilla
gpu_ids: 0
init_gain: 0.02
init_type: normal
input_nc: 3
isTrain: True [default: None]
lambda_L1: 100.0
load_iter: 0 [default: 0]
load_size: 286
lr: 0.0002
lr_decay_iters: 50
lr_policy: linear
max_dataset_size: inf
model: pix2pix
n_epochs: 5
n_epochs_decay: 5
n_layers_D: 3
name: facades_pix2pix [default: experiment_name]
ndf: 64
netD: basic
netG: unet_256
ngf: 64
no_dropout: False
no_flip: False
no_html: False
norm: batch
num_threads: 4
output_nc: 3
phase: train
pool_size: 0
preprocess: resize_and_crop
print_freq: 100
save_by_iter: False
save_epoch_freq: 5
save_latest_freq: 5000
serial_batches: False
suffix:
update_html_freq: 1000
verbose: False
----------------- End -------------------
dataset [AlignedDataset] was created
The number of training images = 405
initialize network with normal
initialize network with normal
model [Pix2PixModel] was created
---------- Networks initialized -------------
[Network G] Total number of parameters : 54.414 M
[Network D] Total number of parameters : 2.769 M
-----------------------------------------------
/home/ma-user/anaconda3/envs/PyTorch-1.4/lib/python3.7/site-packages/torch/optim/lr_scheduler.py:122: UserWarning: Detected call of `lr_scheduler.step()` before `optimizer.step()`. In PyTorch 1.1.0 and later, you should call them in the opposite order: `optimizer.step()` before `lr_scheduler.step()`. Failure to do this will result in PyTorch skipping the first value of the learning rate schedule. See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate
"https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate", UserWarning)
learning rate 0.0002000 -> 0.0002000
(epoch: 1, iters: 100, time: 0.041, data: 0.195)
(epoch: 1, iters: 200, time: 0.041, data: 0.001)
(epoch: 1, iters: 300, time: 0.039, data: 0.001)
(epoch: 1, iters: 400, time: 0.041, data: 0.001)
End of epoch 1 / 10 Time Taken: 13 sec
learning rate 0.0002000 -> 0.0002000
(epoch: 2, iters: 95, time: 0.042, data: 0.001)
(epoch: 2, iters: 195, time: 0.042, data: 0.001)
(epoch: 2, iters: 295, time: 0.042, data: 0.001)
(epoch: 2, iters: 395, time: 0.042, data: 0.001)
End of epoch 2 / 10 Time Taken: 12 sec
learning rate 0.0002000 -> 0.0002000
(epoch: 3, iters: 90, time: 0.039, data: 0.001)
(epoch: 3, iters: 190, time: 0.041, data: 0.001)
(epoch: 3, iters: 290, time: 0.042, data: 0.001)
(epoch: 3, iters: 390, time: 0.042, data: 0.001)
End of epoch 3 / 10 Time Taken: 12 sec
learning rate 0.0002000 -> 0.0002000
(epoch: 4, iters: 85, time: 0.041, data: 0.001)
(epoch: 4, iters: 185, time: 0.042, data: 0.001)
(epoch: 4, iters: 285, time: 0.042, data: 0.001)
(epoch: 4, iters: 385, time: 0.041, data: 0.001)
End of epoch 4 / 10 Time Taken: 12 sec
learning rate 0.0002000 -> 0.0001667
(epoch: 5, iters: 80, time: 0.041, data: 0.001)
(epoch: 5, iters: 180, time: 0.042, data: 0.001)
(epoch: 5, iters: 280, time: 0.042, data: 0.001)
(epoch: 5, iters: 380, time: 0.041, data: 0.001)
saving the model at the end of epoch 5, iters 2025
End of epoch 5 / 10 Time Taken: 14 sec
learning rate 0.0001667 -> 0.0001333
(epoch: 6, iters: 75, time: 0.041, data: 0.001)
(epoch: 6, iters: 175, time: 0.042, data: 0.001)
(epoch: 6, iters: 275, time: 0.041, data: 0.001)
(epoch: 6, iters: 375, time: 0.041, data: 0.001)
End of epoch 6 / 10 Time Taken: 12 sec
learning rate 0.0001333 -> 0.0001000
(epoch: 7, iters: 70, time: 0.041, data: 0.001)
(epoch: 7, iters: 170, time: 0.042, data: 0.001)
(epoch: 7, iters: 270, time: 0.041, data: 0.001)
(epoch: 7, iters: 370, time: 0.041, data: 0.001)
End of epoch 7 / 10 Time Taken: 12 sec
learning rate 0.0001000 -> 0.0000667
(epoch: 8, iters: 65, time: 0.041, data: 0.001)
(epoch: 8, iters: 165, time: 0.041, data: 0.001)
(epoch: 8, iters: 265, time: 0.041, data: 0.001)
(epoch: 8, iters: 365, time: 0.041, data: 0.001)
End of epoch 8 / 10 Time Taken: 12 sec
learning rate 0.0000667 -> 0.0000333
(epoch: 9, iters: 60, time: 0.041, data: 0.001)
(epoch: 9, iters: 160, time: 0.041, data: 0.001)
(epoch: 9, iters: 260, time: 0.041, data: 0.001)
(epoch: 9, iters: 360, time: 0.039, data: 0.001)
End of epoch 9 / 10 Time Taken: 12 sec
learning rate 0.0000333 -> 0.0000000
(epoch: 10, iters: 55, time: 0.041, data: 0.002)
(epoch: 10, iters: 155, time: 0.042, data: 0.001)
(epoch: 10, iters: 255, time: 0.041, data: 0.001)
(epoch: 10, iters: 355, time: 0.040, data: 0.001)
saving the model at the end of epoch 10, iters 4050
End of epoch 10 / 10 Time Taken: 14 sec
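If training is interrupted, the continue_train option listed above should allow resuming from the most recently saved checkpoint. A hedged example (check options/train_options.py for the exact semantics):
!python train.py --dataroot ./datasets/facades --name facades_pix2pix --model pix2pix --direction BtoA --continue_train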
3. Testing
Check that the model facades_pix2pix trained above has been saved; if it has, it will appear under the checkpoints directory.
!ls checkpoints/
facades_label2photo_pretrained facades_pix2pix
Run the test with the trained facades_pix2pix model.
!python test.py --dataroot ./datasets/facades --direction BtoA --model pix2pix --name facades_pix2pix
----------------- Options ---------------
aspect_ratio: 1.0
batch_size: 1
checkpoints_dir: ./checkpoints
crop_size: 256
dataroot: ./datasets/facades [default: None]
dataset_mode: aligned
direction: BtoA [default: AtoB]
display_winsize: 256
epoch: latest
eval: False
gpu_ids: 0
init_gain: 0.02
init_type: normal
input_nc: 3
isTrain: False [default: None]
load_iter: 0 [default: 0]
load_size: 256
max_dataset_size: inf
model: pix2pix [default: test]
n_layers_D: 3
name: facades_pix2pix [default: experiment_name]
ndf: 64
netD: basic
netG: unet_256
ngf: 64
no_dropout: False
no_flip: False
norm: batch
num_test: 50
num_threads: 4
output_nc: 3
phase: test
preprocess: resize_and_crop
results_dir: ./results/
serial_batches: False
suffix:
verbose: False
----------------- End -------------------
dataset [AlignedDataset] was created
initialize network with normal
model [Pix2PixModel] was created
loading the model from ./checkpoints/facades_pix2pix/latest_net_G.pth
---------- Networks initialized -------------
[Network G] Total number of parameters : 54.414 M
-----------------------------------------------
creating web directory ./results/facades_pix2pix/test_latest
processing (0000)-th image... ['./datasets/facades/test/1.jpg']
processing (0005)-th image... ['./datasets/facades/test/103.jpg']
processing (0010)-th image... ['./datasets/facades/test/12.jpg']
processing (0015)-th image... ['./datasets/facades/test/17.jpg']
processing (0020)-th image... ['./datasets/facades/test/21.jpg']
processing (0025)-th image... ['./datasets/facades/test/26.jpg']
processing (0030)-th image... ['./datasets/facades/test/30.jpg']
processing (0035)-th image... ['./datasets/facades/test/35.jpg']
processing (0040)-th image... ['./datasets/facades/test/4.jpg']
processing (0045)-th image... ['./datasets/facades/test/44.jpg']
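Only the first 50 test images are processed by default (the num_test option above). To evaluate more of the test set, the limit can presumably be raised, for example:
!python test.py --dataroot ./datasets/facades --direction BtoA --model pix2pix --name facades_pix2pix --num_test 100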
Display the test results
The generated results can be found under ./results/facades_pix2pix/.
import matplotlib.pyplot as plt
img = plt.imread('./results/facades_pix2pix/test_latest/images/100_fake_B.png')
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7ff680a8cd50>
img = plt.imread('./results/facades_pix2pix/test_latest/images/100_real_A.png')
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7ff680524090>
img = plt.imread('./results/facades_pix2pix/test_latest/images/100_real_B.png')
plt.imshow(img)
<matplotlib.image.AxesImage at 0x7ff6613ceb90>
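To compare the input (real_A), the generated output (fake_B), and the ground truth (real_B) side by side, here is a minimal sketch reusing the same result files shown above:
import matplotlib.pyplot as plt
# Show the input, generated output, and ground truth for one test image
# in a single row.
base = './results/facades_pix2pix/test_latest/images/'
names = ['100_real_A.png', '100_fake_B.png', '100_real_B.png']
titles = ['input (real_A)', 'generated (fake_B)', 'ground truth (real_B)']
fig, axes = plt.subplots(1, 3, figsize=(12, 4))
for ax, name, title in zip(axes, names, titles):
    ax.imshow(plt.imread(base + name))
    ax.set_title(title)
    ax.axis('off')
plt.show()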