【图像翻译】Pix2Pix——基于GAN的图像风格迁移模型【论文解读+代码实验】
【图像翻译】Pix2Pix——基于GAN的图像风格迁移模型【论文解读+代码实验】
CVPR 2017论文pix2pix,在image2image的任务之中具有很好的效果。
Pix2Pix 基于 GAN 架构,利用成对的图片进行图像翻译,即输入为同一张图片的两种不同风格,可用于进行风格迁移。
标签地图——照片
重建的图像——边缘图像
黑白图像——彩色图像
传统GAN原理
Pix2Pix 的结构
一个 GAN 结构的网络至少由两部分构成:生成器模型(Generative Model)与判别器模型(Discriminative Model)。GAN 通过两个模块的互相博弈学习产生相当好的输出。一个优秀的 GAN 需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。
- 生成器(Generator)
Pix2Pix 的生成器采用了一个U-Net架构。U-Net 是一种经典的编码器-解码器结构,具有跳跃连接(skip connections),以保留高分辨率的局部特征,增强图像生成的质量。具体来说:
- 编码器:一系列卷积层将输入图像编码为低分辨率的特征表示。
- 解码器:一系列反卷积层将特征表示解码回原始图像的分辨率。
- 跳跃连接:编码器和解码器的对应层之间通过直接连接传递特征,保留了输入图像的细节信息。
这使得生成器能够很好地处理结构化的图像转换任务,例如从边缘生成图像或从黑白照片生成彩色照片。
生成器优化目标
优化目标的定义,包括2部分,一部分是真实样本(x,y),第二部分x和生成的G(x,z)。组成一个标准的GAN损失。
- 判别器(Discriminator)
Pix2Pix 的判别器采用了PatchGAN架构,而不是传统的全局判别器。
- PatchGAN 判别器:对输入图像的局部小块(Patch)进行判别,而不是对整张图像进行判别。这种设计的好处是可以关注到细节纹理和局部一致性。
- 判别器的目标是区分生成的图像与真实图像是否符合条件输入。
这种设计使 Pix2Pix 更加适用于像素级的图像对齐任务,因为 PatchGAN 只关注局部的图像质量,而不是全局的风格或分布。
PatchGAN其实指的是GAN的判别器,将判别器换成了全卷积网络.
复现练习
数据集
数据集采用的是facades数据集,训练样本是成对匹配的(其中fecades图片均以png格式存储,对应真实样本以jpg存储),共有606张图片,训练过程中将数据集按8:2的比例划分训练集和验证集。
代码
import torch.nn as nn
from torchsummary import summary
import torch
from collections import OrderedDict
# 定义降采样部分
class downsample(nn.Module):
def __init__(self, in_channels, out_channels):
super(downsample, self).__init__()
self.down = nn.Sequential(
nn.LeakyReLU(0.2, True),
nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(out_channels)
)
def forward(self, x):
return self.down(x)
# 定义上采样部分
class upsample(nn.Module):
def __init__(self, in_channels, out_channels, drop_out=False):
super(upsample, self).__init__()
self.up = nn.Sequential(
nn.ReLU(True),
nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1),
nn.BatchNorm2d(out_channels),
nn.Dropout(0.5) if drop_out else nn.Identity()
)
def forward(self, x):
return self.up(x)
# ---------------------------------------------------------------------------------
# 定义pix_G =>input 128*128
class pix2pixG_128(nn.Module):
def __init__(self):
super(pix2pixG_128, self).__init__()
# down sample
self.down_1 = nn.Conv2d(3, 64, 4, 2, 1) # [batch,3,128,128]=>[batch,64,64,64]
for i in range(7):
if i == 0:
self.down_2 = downsample(64, 128) # [batch,64,64,64]=>[batch,128,32,32]
self.down_3 = downsample(128, 256) # [batch,128,32,32]=>[batch,256,16,16]
self.down_4 = downsample(256, 512) # [batch,256,16,16]=>[batch,512,8,8]
self.down_5 = downsample(512, 512) # [batch,512,8,8]=>[batch,512,4,4]
self.down_6 = downsample(512, 512) # [batch,512,4,4]=>[batch,512,2,2]
self.down_7 = downsample(512, 512) # [batch,512,2,2]=>[batch,512,1,1]
# up_sample
for i in range(7):
if i == 0:
self.up_1 = upsample(512, 512) # [batch,512,1,1]=>[batch,512,2,2]
self.up_2 = upsample(1024, 512, drop_out=True) # [batch,1024,2,2]=>[batch,512,4,4]
self.up_3 = upsample(1024, 512, drop_out=True) # [batch,1024,4,4]=>[batch,512,8,8]
self.up_4 = upsample(1024, 256) # [batch,1024,8,8]=>[batch,256,16,16]
self.up_5 = upsample(512, 128) # [batch,512,16,16]=>[batch,128,32,32]
self.up_6 = upsample(256, 64) # [batch,256,32,32]=>[batch,64,64,64]
self.last_Conv = nn.Sequential(
nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1),
nn.Tanh()
)
self.init_weight()
def init_weight(self):
for w in self.modules():
if isinstance(w, nn.Conv2d):
nn.init.kaiming_normal_(w.weight, mode='fan_out')
if w.bias is not None:
nn.init.zeros_(w.bias)
elif isinstance(w, nn.ConvTranspose2d):
nn.init.kaiming_normal_(w.weight, mode='fan_in')
elif isinstance(w, nn.BatchNorm2d):
nn.init.ones_(w.weight)
nn.init.zeros_(w.bias)
def forward(self, x):
# down
down_1 = self.down_1(x)
down_2 = self.down_2(down_1)
down_3 = self.down_3(down_2)
down_4 = self.down_4(down_3)
down_5 = self.down_5(down_4)
down_6 = self.down_6(down_5)
down_7 = self.down_7(down_6)
# up
up_1 = self.up_1(down_7)
up_2 = self.up_2(torch.cat([up_1, down_6], dim=1))
up_3 = self.up_3(torch.cat([up_2, down_5], dim=1))
up_4 = self.up_4(torch.cat([up_3, down_4], dim=1))
up_5 = self.up_5(torch.cat([up_4, down_3], dim=1))
up_6 = self.up_6(torch.cat([up_5, down_2], dim=1))
out = self.last_Conv(torch.cat([up_6, down_1], dim=1))
return out
# 定义pix_D_128 => input 128*128
class pix2pixD_128(nn.Module):
def __init__(self):
super(pix2pixD_128, self).__init__()
# 定义基本的卷积\bn\relu
def base_Conv_bn_lkrl(in_channels, out_channels, stride):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 4, stride, 1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(0.2)
)
D_dic = OrderedDict()
in_channels = 6
out_channels = 64
for i in range(4):
if i < 3:
D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 2)})
else:
D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 1)})
in_channels = out_channels
out_channels *= 2
D_dic.update({'last_layer': nn.Conv2d(512, 1, 4, 1, 1)}) # [batch,1,14,14]
self.D_model = nn.Sequential(D_dic)
def forward(self, x1, x2):
in_x = torch.cat([x1, x2], dim=1)
return self.D_model(in_x)
# ---------------------------------------------------------------------------------
# 256*256
class pix2pixG_256(nn.Module):
def __init__(self):
super(pix2pixG_256, self).__init__()
# down sample
self.down_1 = nn.Conv2d(3, 64, 4, 2, 1) # [batch,3,256,256]=>[batch,64,128,128]
for i in range(7):
if i == 0:
self.down_2 = downsample(64, 128) # [batch,64,128,128]=>[batch,128,64,64]
self.down_3 = downsample(128, 256) # [batch,128,64,64]=>[batch,256,32,32]
self.down_4 = downsample(256, 512) # [batch,256,32,32]=>[batch,512,16,16]
self.down_5 = downsample(512, 512) # [batch,512,16,16]=>[batch,512,8,8]
self.down_6 = downsample(512, 512) # [batch,512,8,8]=>[batch,512,4,4]
self.down_7 = downsample(512, 512) # [batch,512,4,4]=>[batch,512,2,2]
self.down_8 = downsample(512, 512) # [batch,512,2,2]=>[batch,512,1,1]
# up_sample
for i in range(7):
if i == 0:
self.up_1 = upsample(512, 512) # [batch,512,1,1]=>[batch,512,2,2]
self.up_2 = upsample(1024, 512, drop_out=True) # [batch,1024,2,2]=>[batch,512,4,4]
self.up_3 = upsample(1024, 512, drop_out=True) # [batch,1024,4,4]=>[batch,512,8,8]
self.up_4 = upsample(1024, 512) # [batch,1024,8,8]=>[batch,512,16,16]
self.up_5 = upsample(1024, 256) # [batch,1024,16,16]=>[batch,256,32,32]
self.up_6 = upsample(512, 128) # [batch,512,32,32]=>[batch,128,64,64]
self.up_7 = upsample(256, 64) # [batch,256,64,64]=>[batch,64,128,128]
self.last_Conv = nn.Sequential(
nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1),
nn.Tanh()
)
self.init_weight()
def init_weight(self):
for w in self.modules():
if isinstance(w, nn.Conv2d):
nn.init.kaiming_normal_(w.weight, mode='fan_out')
if w.bias is not None:
nn.init.zeros_(w.bias)
elif isinstance(w, nn.ConvTranspose2d):
nn.init.kaiming_normal_(w.weight, mode='fan_in')
elif isinstance(w, nn.BatchNorm2d):
nn.init.ones_(w.weight)
nn.init.zeros_(w.bias)
def forward(self, x):
# down
down_1 = self.down_1(x)
down_2 = self.down_2(down_1)
down_3 = self.down_3(down_2)
down_4 = self.down_4(down_3)
down_5 = self.down_5(down_4)
down_6 = self.down_6(down_5)
down_7 = self.down_7(down_6)
down_8 = self.down_8(down_7)
# up
up_1 = self.up_1(down_8)
up_2 = self.up_2(torch.cat([up_1, down_7], dim=1))
up_3 = self.up_3(torch.cat([up_2, down_6], dim=1))
up_4 = self.up_4(torch.cat([up_3, down_5], dim=1))
up_5 = self.up_5(torch.cat([up_4, down_4], dim=1))
up_6 = self.up_6(torch.cat([up_5, down_3], dim=1))
up_7 = self.up_7(torch.cat([up_6, down_2], dim=1))
out = self.last_Conv(torch.cat([up_7, down_1], dim=1))
return out
# 256*256
class pix2pixD_256(nn.Module):
def __init__(self):
super(pix2pixD_256, self).__init__()
# 定义基本的卷积\bn\relu
def base_Conv_bn_lkrl(in_channels, out_channels, stride):
return nn.Sequential(
nn.Conv2d(in_channels, out_channels, 4, stride, 1),
nn.BatchNorm2d(out_channels),
nn.LeakyReLU(0.2)
)
D_dic = OrderedDict()
in_channels = 6
out_channels = 64
for i in range(4):
if i < 3:
D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 2)})
else:
D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 1)})
in_channels = out_channels
out_channels *= 2
D_dic.update({'last_layer': nn.Conv2d(512, 1, 4, 1, 1)}) # [batch,1,30,30]
self.D_model = nn.Sequential(D_dic)
def forward(self, x1, x2):
in_x = torch.cat([x1, x2], dim=1)
return self.D_model(in_x)
if __name__ == '__main__':
G = pix2pixG_256().to('cpu')
summary(G, (3, 256, 256))
降采样部分 (downsample
类)
- 作用:实现特征图的降采样,用于提取高层语义特征,同时减小特征图尺寸。
- 结构:
- 激活函数:
LeakyReLU
,以 0.2 的负斜率处理负值。 - 卷积层:
Conv2d
,用 4x4 核心,步幅为 2,填充为 1,进行空间降采样。 - 批归一化:
BatchNorm2d
,加速训练和稳定模型。
- 激活函数:
- 前向传播:按顺序调用激活、卷积和归一化操作。
上采样部分 (upsample
类)
- 作用:对特征图进行上采样,用于恢复空间分辨率。
- 结构:
- 激活函数:
ReLU
,对特征进行正值激活。 - 反卷积:
ConvTranspose2d
,用 4x4 核心,步幅为 2,填充为 1,进行上采样。 - 批归一化:
BatchNorm2d
。 - 随机失活(可选):
Dropout
,根据drop_out
参数决定是否引入,以增加鲁棒性。
- 激活函数:
- 前向传播:按顺序调用激活、反卷积、归一化,以及随机失活操作(若有)。
生成器 (pix2pixG_128
和 pix2pixG_256
类)
pix2pixG_128
类
- 输入大小:128x128 的图像。
- 作用:构建 U-Net 风格的生成器,用于生成与输入图像相对应的输出。
- 降采样部分:
down_1
是输入的第一层卷积,用于提取低级特征。- 后续降采样层(
down_2
到down_7
)通过downsample
实现,特征逐步被压缩。
- 上采样部分:
- 从最小的特征图开始逐层上采样(
up_1
到up_6
),并通过跳跃连接(torch.cat
)与降采样层输出进行拼接,实现特征融合。
- 从最小的特征图开始逐层上采样(
- 输出部分:
last_Conv
使用反卷积和Tanh
激活函数生成最终图像。
- 权重初始化:通过
init_weight
函数对卷积和归一化层的权重进行初始化。 - 前向传播:
- 依次经过降采样和上采样过程,跳跃连接用于保留低级细节信息。
pix2pixG_256
类
- 输入大小:256x256 的图像。
- 区别:
- 增加了一层降采样(
down_8
)和对应的上采样(up_7
)。 - 适配更大的输入图像。
- 增加了一层降采样(
判别器 (pix2pixD_128
类)
-
作用:用于判断输入图像是否为生成器生成的图像。
-
结构:
- 层数:包括 4 个卷积层和一个最后的分类层(
last_layer
)。 - 每层卷积都包含:
Conv2d
->BatchNorm2d
->LeakyReLU
。 - 最后一层不使用归一化,只输出一个 1x1 的预测值。
- 层数:包括 4 个卷积层和一个最后的分类层(
-
输入:
- 两张图像(通常是输入图像和生成图像)。
- 使用
torch.cat
进行通道维度拼接。
-
输出:
- 一个 1x1 的值,用于表征真假。
-
该代码实现了基于 U-Net 风格的 Pix2Pix 生成器(支持 128x128 和 256x256 输入)和一个 PatchGAN 判别器。
-
生成器和判别器均结合了卷积、归一化和激活函数,生成器中包含跳跃连接以增强低级信息的保留。
-
整体框架为 Pix2Pix 的典型实现,适用于图像到图像的转换任务(如风格迁移、图像修复等)。
import torchvision
from tqdm import tqdm
import torch
import os
def train_one_epoch(G, D, train_loader, optim_G, optim_D, writer, loss, device, plot_every, epoch, l1_loss):
pd = tqdm(train_loader)
loss_D, loss_G = 0, 0
step = 0
G.train()
D.train()
for idx, data in enumerate(pd):
in_img = data[0].to(device)
real_img = data[1].to(device)
# 先训练D
fake_img = G(in_img)
D_fake_out = D(fake_img.detach(), in_img).squeeze()
D_real_out = D(real_img, in_img).squeeze()
ls_D1 = loss(D_fake_out, torch.zeros(D_fake_out.size()).cuda())
ls_D2 = loss(D_real_out, torch.ones(D_real_out.size()).cuda())
ls_D = (ls_D1 + ls_D2) * 0.5
optim_D.zero_grad()
ls_D.backward()
optim_D.step()
# 再训练G
fake_img = G(in_img)
D_fake_out = D(fake_img, in_img).squeeze()
ls_G1 = loss(D_fake_out, torch.ones(D_fake_out.size()).cuda())
ls_G2 = l1_loss(fake_img, real_img)
ls_G = ls_G1 + ls_G2 * 100
optim_G.zero_grad()
ls_G.backward()
optim_G.step()
loss_D += ls_D
loss_G += ls_G
pd.desc = 'train_{} G_loss: {} D_loss: {}'.format(epoch, ls_G.item(), ls_D.item())
# 绘制训练结果
if idx % plot_every == 0:
writer.add_images(tag='train_epoch_{}'.format(epoch), img_tensor=0.5 * (fake_img + 1), global_step=step)
step += 1
mean_lsG = loss_G / len(train_loader)
mean_lsD = loss_D / len(train_loader)
return mean_lsG, mean_lsD
@torch.no_grad()
def val(G, D, val_loader, loss, device, l1_loss, epoch):
pd = tqdm(val_loader)
loss_D, loss_G = 0, 0
G.eval()
D.eval()
all_loss = 10000
for idx, item in enumerate(pd):
in_img = item[0].to(device)
real_img = item[1].to(device)
fake_img = G(in_img)
D_fake_out = D(fake_img, in_img).squeeze()
ls_D1 = loss(D_fake_out, torch.zeros(D_fake_out.size()).cuda())
ls_D = ls_D1 * 0.5
ls_G1 = loss(D_fake_out, torch.ones(D_fake_out.size()).cuda())
ls_G2 = l1_loss(fake_img, real_img)
ls_G = ls_G1 + ls_G2 * 100
loss_G += ls_G
loss_D += ls_D
pd.desc = 'val_{}: G_loss:{} D_Loss:{}'.format(epoch, ls_G.item(), ls_D.item())
# 保存最好的结果
all_ls = ls_G + ls_D
if all_ls < all_loss:
all_loss = all_ls
best_image = fake_img
result_img = (best_image + 1) * 0.5
if not os.path.exists('./results'):
os.mkdir('./results')
torchvision.utils.save_image(result_img, './results/val_epoch{}.jpg'.format(epoch))
from pix2Topix import pix2pixG_256
import torch
import torchvision.transforms as transform
import matplotlib.pyplot as plt
import cv2
from PIL import Image
def run_test(img_path, original_path, compare_path):
# 加载生成图像
if img_path.endswith('.png'):
img = cv2.imread(img_path)
img = img[:, :, ::-1] # BGR -> RGB
else:
img = Image.open(img_path)
# 对输入图片进行预处理
transforms = transform.Compose([
transform.ToTensor(),
transform.Resize((256, 256), antialias=True),
transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
img = transforms(img.copy())
img = img[None].to('cpu') # 移到 CPU
# 实例化生成器模型
G = pix2pixG_256().to('cpu') # 移到 CPU
# 加载预训练权重
ckpt = torch.load('weights/pix2pix_256.pth', map_location='cpu') # 加载到 CPU
G.load_state_dict(ckpt['G_model'], strict=False)
G.eval()
out = G(img)[0]
out = out.permute(1, 2, 0) # 转换为 HWC 格式
out = (0.5 * (out + 1)).cpu().detach().numpy() # 反归一化
# 加载原始图片并调整为 256x256
original_img = Image.open(original_path)
original_img = original_img.resize((256, 256))
original_img = transform.ToTensor()(original_img).permute(1, 2, 0).numpy() # 转换为 HWC 格式
# original_img = (0.5 * (original_img + 1)) # 反归一化
# 加载另一张原图进行对比
compare_img = Image.open(compare_path)
compare_img = compare_img.resize((256, 256))
compare_img = transform.ToTensor()(compare_img).permute(1, 2, 0).numpy() # 转换为 HWC 格式
compare_img = (0.5 * (compare_img + 1)) # 反归一化
# 绘制对比图
plt.figure(figsize=(15, 5))
# 显示第一张原始图片
plt.subplot(1, 3, 1)
plt.imshow(original_img) # 使用反归一化后的图像
plt.title('Original Image 1')
plt.axis('off')
# 显示第二张原始图片
plt.subplot(1, 3, 2)
plt.imshow(compare_img) # 使用反归一化后的图像
plt.title('Original Image 2')
plt.axis('off')
# 显示生成的图片
plt.subplot(1, 3, 3)
plt.imshow(out)
plt.title('Generated Image')
plt.axis('off')
plt.show()
if __name__ == '__main__':
run_test('./base/cmp_b0007.png', './base/cmp_b0007.jpg', './base/cmp_b0007.png')
测试效果:
原图-语义分割图-生成图
小结
生成器结构
判别器结构
优化目标
- 点赞
- 收藏
- 关注作者
评论(0)