【图像翻译】Pix2Pix——基于GAN的图像风格迁移模型【论文解读+代码实验】

举报
柠檬味拥抱 发表于 2024/11/29 16:50:38 2024/11/29
【摘要】 【图像翻译】Pix2Pix——基于GAN的图像风格迁移模型【论文解读+代码实验】CVPR 2017论文pix2pix,在image2image的任务之中具有很好的效果。Pix2Pix 基于 GAN 架构,利用成对的图片进行图像翻译,即输入为同一张图片的两种不同风格,可用于进行风格迁移。标签地图——照片重建的图像——边缘图像黑白图像——彩色图像 传统GAN原理 Pix2Pix 的结构一个 G...

【图像翻译】Pix2Pix——基于GAN的图像风格迁移模型【论文解读+代码实验】

image.png

CVPR 2017论文pix2pix,在image2image的任务之中具有很好的效果。

Pix2Pix 基于 GAN 架构,利用成对的图片进行图像翻译,即输入为同一张图片的两种不同风格,可用于进行风格迁移。

image.png

标签地图——照片

重建的图像——边缘图像

黑白图像——彩色图像

传统GAN原理

image.png

Pix2Pix 的结构

一个 GAN 结构的网络至少由两部分构成:生成器模型(Generative Model)与判别器模型(Discriminative Model)。GAN 通过两个模块的互相博弈学习产生相当好的输出。一个优秀的 GAN 需要有良好的训练方法,否则可能由于神经网络模型的自由性而导致输出不理想。

  1. 生成器(Generator)

image.png

Pix2Pix 的生成器采用了一个U-Net架构。U-Net 是一种经典的编码器-解码器结构,具有跳跃连接(skip connections),以保留高分辨率的局部特征,增强图像生成的质量。具体来说:

  • 编码器:一系列卷积层将输入图像编码为低分辨率的特征表示。
  • 解码器:一系列反卷积层将特征表示解码回原始图像的分辨率。
  • 跳跃连接:编码器和解码器的对应层之间通过直接连接传递特征,保留了输入图像的细节信息。

这使得生成器能够很好地处理结构化的图像转换任务,例如从边缘生成图像或从黑白照片生成彩色照片。

image.png

生成器优化目标

image.png

优化目标的定义,包括2部分,一部分是真实样本(x,y),第二部分x和生成的G(x,z)。组成一个标准的GAN损失。

  1. 判别器(Discriminator)

Pix2Pix 的判别器采用了PatchGAN架构,而不是传统的全局判别器。

  • PatchGAN 判别器:对输入图像的局部小块(Patch)进行判别,而不是对整张图像进行判别。这种设计的好处是可以关注到细节纹理和局部一致性。
  • 判别器的目标是区分生成的图像与真实图像是否符合条件输入。

这种设计使 Pix2Pix 更加适用于像素级的图像对齐任务,因为 PatchGAN 只关注局部的图像质量,而不是全局的风格或分布。

PatchGAN其实指的是GAN的判别器,将判别器换成了全卷积网络.

image.png

复现练习

数据集

数据集采用的是facades数据集,训练样本是成对匹配的(其中fecades图片均以png格式存储,对应真实样本以jpg存储),共有606张图片,训练过程中将数据集按8:2的比例划分训练集和验证集。

image.png

代码

pix2Topix.py

import torch.nn as nn
from torchsummary import summary
import torch
from collections import OrderedDict


# 定义降采样部分
class downsample(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(downsample, self).__init__()
        self.down = nn.Sequential(
            nn.LeakyReLU(0.2, True),
            nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels)
        )

    def forward(self, x):
        return self.down(x)


# 定义上采样部分
class upsample(nn.Module):
    def __init__(self, in_channels, out_channels, drop_out=False):
        super(upsample, self).__init__()
        self.up = nn.Sequential(
            nn.ReLU(True),
            nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.Dropout(0.5) if drop_out else nn.Identity()
        )

    def forward(self, x):
        return self.up(x)


# ---------------------------------------------------------------------------------
# 定义pix_G  =>input 128*128
class pix2pixG_128(nn.Module):
    def __init__(self):
        super(pix2pixG_128, self).__init__()
        # down sample
        self.down_1 = nn.Conv2d(3, 64, 4, 2, 1)  # [batch,3,128,128]=>[batch,64,64,64]
        for i in range(7):
            if i == 0:
                self.down_2 = downsample(64, 128)  # [batch,64,64,64]=>[batch,128,32,32]
                self.down_3 = downsample(128, 256)  # [batch,128,32,32]=>[batch,256,16,16]
                self.down_4 = downsample(256, 512)  # [batch,256,16,16]=>[batch,512,8,8]
                self.down_5 = downsample(512, 512)  # [batch,512,8,8]=>[batch,512,4,4]
                self.down_6 = downsample(512, 512)  # [batch,512,4,4]=>[batch,512,2,2]
                self.down_7 = downsample(512, 512)  # [batch,512,2,2]=>[batch,512,1,1]

        # up_sample
        for i in range(7):
            if i == 0:
                self.up_1 = upsample(512, 512)  # [batch,512,1,1]=>[batch,512,2,2]
                self.up_2 = upsample(1024, 512, drop_out=True)  # [batch,1024,2,2]=>[batch,512,4,4]
                self.up_3 = upsample(1024, 512, drop_out=True)  # [batch,1024,4,4]=>[batch,512,8,8]
                self.up_4 = upsample(1024, 256)  # [batch,1024,8,8]=>[batch,256,16,16]
                self.up_5 = upsample(512, 128)  # [batch,512,16,16]=>[batch,128,32,32]
                self.up_6 = upsample(256, 64)  # [batch,256,32,32]=>[batch,64,64,64]

        self.last_Conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

        self.init_weight()

    def init_weight(self):
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_in')
            elif isinstance(w, nn.BatchNorm2d):
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)

    def forward(self, x):
        # down
        down_1 = self.down_1(x)
        down_2 = self.down_2(down_1)
        down_3 = self.down_3(down_2)
        down_4 = self.down_4(down_3)
        down_5 = self.down_5(down_4)
        down_6 = self.down_6(down_5)
        down_7 = self.down_7(down_6)
        # up
        up_1 = self.up_1(down_7)
        up_2 = self.up_2(torch.cat([up_1, down_6], dim=1))
        up_3 = self.up_3(torch.cat([up_2, down_5], dim=1))
        up_4 = self.up_4(torch.cat([up_3, down_4], dim=1))
        up_5 = self.up_5(torch.cat([up_4, down_3], dim=1))
        up_6 = self.up_6(torch.cat([up_5, down_2], dim=1))
        out = self.last_Conv(torch.cat([up_6, down_1], dim=1))
        return out


# 定义pix_D_128    => input 128*128
class pix2pixD_128(nn.Module):
    def __init__(self):
        super(pix2pixD_128, self).__init__()

        # 定义基本的卷积\bn\relu
        def base_Conv_bn_lkrl(in_channels, out_channels, stride):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, stride, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2)
            )

        D_dic = OrderedDict()
        in_channels = 6
        out_channels = 64
        for i in range(4):
            if i < 3:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 2)})
            else:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 1)})
            in_channels = out_channels
            out_channels *= 2
        D_dic.update({'last_layer': nn.Conv2d(512, 1, 4, 1, 1)})  # [batch,1,14,14]
        self.D_model = nn.Sequential(D_dic)

    def forward(self, x1, x2):
        in_x = torch.cat([x1, x2], dim=1)
        return self.D_model(in_x)


# ---------------------------------------------------------------------------------
# 256*256
class pix2pixG_256(nn.Module):
    def __init__(self):
        super(pix2pixG_256, self).__init__()
        # down sample
        self.down_1 = nn.Conv2d(3, 64, 4, 2, 1)  # [batch,3,256,256]=>[batch,64,128,128]
        for i in range(7):
            if i == 0:
                self.down_2 = downsample(64, 128)  # [batch,64,128,128]=>[batch,128,64,64]
                self.down_3 = downsample(128, 256)  # [batch,128,64,64]=>[batch,256,32,32]
                self.down_4 = downsample(256, 512)  # [batch,256,32,32]=>[batch,512,16,16]
                self.down_5 = downsample(512, 512)  # [batch,512,16,16]=>[batch,512,8,8]
                self.down_6 = downsample(512, 512)  # [batch,512,8,8]=>[batch,512,4,4]
                self.down_7 = downsample(512, 512)  # [batch,512,4,4]=>[batch,512,2,2]
                self.down_8 = downsample(512, 512)  # [batch,512,2,2]=>[batch,512,1,1]

        # up_sample
        for i in range(7):
            if i == 0:
                self.up_1 = upsample(512, 512)  # [batch,512,1,1]=>[batch,512,2,2]
                self.up_2 = upsample(1024, 512, drop_out=True)  # [batch,1024,2,2]=>[batch,512,4,4]
                self.up_3 = upsample(1024, 512, drop_out=True)  # [batch,1024,4,4]=>[batch,512,8,8]
                self.up_4 = upsample(1024, 512)  # [batch,1024,8,8]=>[batch,512,16,16]
                self.up_5 = upsample(1024, 256)  # [batch,1024,16,16]=>[batch,256,32,32]
                self.up_6 = upsample(512, 128)  # [batch,512,32,32]=>[batch,128,64,64]
                self.up_7 = upsample(256, 64)  # [batch,256,64,64]=>[batch,64,128,128]

        self.last_Conv = nn.Sequential(
            nn.ConvTranspose2d(in_channels=128, out_channels=3, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        )

        self.init_weight()

    def init_weight(self):
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_in')
            elif isinstance(w, nn.BatchNorm2d):
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)

    def forward(self, x):
        # down
        down_1 = self.down_1(x)
        down_2 = self.down_2(down_1)
        down_3 = self.down_3(down_2)
        down_4 = self.down_4(down_3)
        down_5 = self.down_5(down_4)
        down_6 = self.down_6(down_5)
        down_7 = self.down_7(down_6)
        down_8 = self.down_8(down_7)
        # up
        up_1 = self.up_1(down_8)
        up_2 = self.up_2(torch.cat([up_1, down_7], dim=1))
        up_3 = self.up_3(torch.cat([up_2, down_6], dim=1))
        up_4 = self.up_4(torch.cat([up_3, down_5], dim=1))
        up_5 = self.up_5(torch.cat([up_4, down_4], dim=1))
        up_6 = self.up_6(torch.cat([up_5, down_3], dim=1))
        up_7 = self.up_7(torch.cat([up_6, down_2], dim=1))
        out = self.last_Conv(torch.cat([up_7, down_1], dim=1))
        return out


# 256*256
class pix2pixD_256(nn.Module):
    def __init__(self):
        super(pix2pixD_256, self).__init__()

        # 定义基本的卷积\bn\relu
        def base_Conv_bn_lkrl(in_channels, out_channels, stride):
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, stride, 1),
                nn.BatchNorm2d(out_channels),
                nn.LeakyReLU(0.2)
            )

        D_dic = OrderedDict()
        in_channels = 6
        out_channels = 64
        for i in range(4):
            if i < 3:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 2)})
            else:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 1)})
            in_channels = out_channels
            out_channels *= 2
        D_dic.update({'last_layer': nn.Conv2d(512, 1, 4, 1, 1)})  # [batch,1,30,30]
        self.D_model = nn.Sequential(D_dic)

    def forward(self, x1, x2):
        in_x = torch.cat([x1, x2], dim=1)
        return self.D_model(in_x)


if __name__ == '__main__':
    G = pix2pixG_256().to('cpu')
    summary(G, (3, 256, 256))

降采样部分 (downsample 类)

  • 作用:实现特征图的降采样,用于提取高层语义特征,同时减小特征图尺寸。
  • 结构:
    • 激活函数:LeakyReLU,以 0.2 的负斜率处理负值。
    • 卷积层:Conv2d,用 4x4 核心,步幅为 2,填充为 1,进行空间降采样。
    • 批归一化:BatchNorm2d,加速训练和稳定模型。
  • 前向传播:按顺序调用激活、卷积和归一化操作。

上采样部分 (upsample 类)

  • 作用:对特征图进行上采样,用于恢复空间分辨率。
  • 结构:
    • 激活函数:ReLU,对特征进行正值激活。
    • 反卷积:ConvTranspose2d,用 4x4 核心,步幅为 2,填充为 1,进行上采样。
    • 批归一化:BatchNorm2d
    • 随机失活(可选):Dropout,根据 drop_out 参数决定是否引入,以增加鲁棒性。
  • 前向传播:按顺序调用激活、反卷积、归一化,以及随机失活操作(若有)。

生成器 (pix2pixG_128pix2pixG_256 类)

pix2pixG_128

  • 输入大小:128x128 的图像。
  • 作用:构建 U-Net 风格的生成器,用于生成与输入图像相对应的输出。
  • 降采样部分:
    • down_1 是输入的第一层卷积,用于提取低级特征。
    • 后续降采样层(down_2down_7)通过 downsample 实现,特征逐步被压缩。
  • 上采样部分:
    • 从最小的特征图开始逐层上采样(up_1up_6),并通过跳跃连接(torch.cat)与降采样层输出进行拼接,实现特征融合。
  • 输出部分:
    • last_Conv 使用反卷积和 Tanh 激活函数生成最终图像。
  • 权重初始化:通过 init_weight 函数对卷积和归一化层的权重进行初始化。
  • 前向传播:
    • 依次经过降采样和上采样过程,跳跃连接用于保留低级细节信息。

pix2pixG_256

  • 输入大小:256x256 的图像。
  • 区别:
    • 增加了一层降采样(down_8)和对应的上采样(up_7)。
    • 适配更大的输入图像。

判别器 (pix2pixD_128 类)

  • 作用:用于判断输入图像是否为生成器生成的图像。

  • 结构:

    • 层数:包括 4 个卷积层和一个最后的分类层(last_layer)。
    • 每层卷积都包含:Conv2d -> BatchNorm2d -> LeakyReLU
    • 最后一层不使用归一化,只输出一个 1x1 的预测值。
  • 输入:

    • 两张图像(通常是输入图像和生成图像)。
    • 使用 torch.cat 进行通道维度拼接。
  • 输出:

    • 一个 1x1 的值,用于表征真假。
  • 该代码实现了基于 U-Net 风格的 Pix2Pix 生成器(支持 128x128 和 256x256 输入)和一个 PatchGAN 判别器。

  • 生成器和判别器均结合了卷积、归一化和激活函数,生成器中包含跳跃连接以增强低级信息的保留。

  • 整体框架为 Pix2Pix 的典型实现,适用于图像到图像的转换任务(如风格迁移、图像修复等)。

train.py

import torchvision
from tqdm import tqdm
import torch
import os
 
 
def train_one_epoch(G, D, train_loader, optim_G, optim_D, writer, loss, device, plot_every, epoch, l1_loss):
    pd = tqdm(train_loader)
    loss_D, loss_G = 0, 0
    step = 0
    G.train()
    D.train()
    for idx, data in enumerate(pd):
        in_img = data[0].to(device)
        real_img = data[1].to(device)
        # 先训练D
        fake_img = G(in_img)
        D_fake_out = D(fake_img.detach(), in_img).squeeze()
        D_real_out = D(real_img, in_img).squeeze()
        ls_D1 = loss(D_fake_out, torch.zeros(D_fake_out.size()).cuda())
        ls_D2 = loss(D_real_out, torch.ones(D_real_out.size()).cuda())
        ls_D = (ls_D1 + ls_D2) * 0.5
 
        optim_D.zero_grad()
        ls_D.backward()
        optim_D.step()
 
        # 再训练G
        fake_img = G(in_img)
        D_fake_out = D(fake_img, in_img).squeeze()
        ls_G1 = loss(D_fake_out, torch.ones(D_fake_out.size()).cuda())
        ls_G2 = l1_loss(fake_img, real_img)
        ls_G = ls_G1 + ls_G2 * 100
 
        optim_G.zero_grad()
        ls_G.backward()
        optim_G.step()
 
        loss_D += ls_D
        loss_G += ls_G
 
        pd.desc = 'train_{} G_loss: {} D_loss: {}'.format(epoch, ls_G.item(), ls_D.item())
        # 绘制训练结果
        if idx % plot_every == 0:
            writer.add_images(tag='train_epoch_{}'.format(epoch), img_tensor=0.5 * (fake_img + 1), global_step=step)
            step += 1
    mean_lsG = loss_G / len(train_loader)
    mean_lsD = loss_D / len(train_loader)
    return mean_lsG, mean_lsD
 
 
@torch.no_grad()
def val(G, D, val_loader, loss, device, l1_loss, epoch):
    pd = tqdm(val_loader)
    loss_D, loss_G = 0, 0
    G.eval()
    D.eval()
    all_loss = 10000
    for idx, item in enumerate(pd):
        in_img = item[0].to(device)
        real_img = item[1].to(device)
        fake_img = G(in_img)
        D_fake_out = D(fake_img, in_img).squeeze()
        ls_D1 = loss(D_fake_out, torch.zeros(D_fake_out.size()).cuda())
        ls_D = ls_D1 * 0.5
        ls_G1 = loss(D_fake_out, torch.ones(D_fake_out.size()).cuda())
        ls_G2 = l1_loss(fake_img, real_img)
        ls_G = ls_G1 + ls_G2 * 100
        loss_G += ls_G
        loss_D += ls_D
        pd.desc = 'val_{}: G_loss:{} D_Loss:{}'.format(epoch, ls_G.item(), ls_D.item())
 
        # 保存最好的结果
        all_ls = ls_G + ls_D
        if all_ls < all_loss:
            all_loss = all_ls
            best_image = fake_img
    result_img = (best_image + 1) * 0.5
    if not os.path.exists('./results'):
        os.mkdir('./results')
 
    torchvision.utils.save_image(result_img, './results/val_epoch{}.jpg'.format(epoch))

test.py

from pix2Topix import pix2pixG_256
import torch
import torchvision.transforms as transform
import matplotlib.pyplot as plt
import cv2
from PIL import Image

def run_test(img_path, original_path, compare_path):
    # 加载生成图像
    if img_path.endswith('.png'):
        img = cv2.imread(img_path)
        img = img[:, :, ::-1]  # BGR -> RGB
    else:
        img = Image.open(img_path)

    # 对输入图片进行预处理
    transforms = transform.Compose([
        transform.ToTensor(),
        transform.Resize((256, 256), antialias=True),
        transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    img = transforms(img.copy())
    img = img[None].to('cpu')  # 移到 CPU

    # 实例化生成器模型
    G = pix2pixG_256().to('cpu')  # 移到 CPU
    # 加载预训练权重
    ckpt = torch.load('weights/pix2pix_256.pth', map_location='cpu')  # 加载到 CPU
    G.load_state_dict(ckpt['G_model'], strict=False)

    G.eval()
    out = G(img)[0]
    out = out.permute(1, 2, 0)  # 转换为 HWC 格式
    out = (0.5 * (out + 1)).cpu().detach().numpy()  # 反归一化

    # 加载原始图片并调整为 256x256
    original_img = Image.open(original_path)
    original_img = original_img.resize((256, 256))
    original_img = transform.ToTensor()(original_img).permute(1, 2, 0).numpy()  # 转换为 HWC 格式
    # original_img = (0.5 * (original_img + 1))  # 反归一化

    # 加载另一张原图进行对比
    compare_img = Image.open(compare_path)
    compare_img = compare_img.resize((256, 256))
    compare_img = transform.ToTensor()(compare_img).permute(1, 2, 0).numpy()  # 转换为 HWC 格式
    compare_img = (0.5 * (compare_img + 1))  # 反归一化

    # 绘制对比图
    plt.figure(figsize=(15, 5))

    # 显示第一张原始图片
    plt.subplot(1, 3, 1)
    plt.imshow(original_img)  # 使用反归一化后的图像
    plt.title('Original Image 1')
    plt.axis('off')

    # 显示第二张原始图片
    plt.subplot(1, 3, 2)
    plt.imshow(compare_img)  # 使用反归一化后的图像
    plt.title('Original Image 2')
    plt.axis('off')

    # 显示生成的图片
    plt.subplot(1, 3, 3)
    plt.imshow(out)
    plt.title('Generated Image')
    plt.axis('off')

    plt.show()


if __name__ == '__main__':
    run_test('./base/cmp_b0007.png', './base/cmp_b0007.jpg', './base/cmp_b0007.png')

测试效果:

原图-语义分割图-生成图
image.png

image.png

image.png

image.png

小结

生成器结构

判别器结构

优化目标

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。