对抗生成网络:指马为斑马

举报
darkpard 发表于 2022/09/18 10:37:14 2022/09/18
【摘要】 生成对抗网络GAN是去年以来比较火的一个技术,它通过一个生成网络来形成新的内容,再通过一个判别网络来判断生成的内容是否是想要的内容。一个简单的实现如下(非原创):# ResNetGeneratorimport torchimport torch.nn as nnclass ResNetBlock(nn.Module): def __init__(self, dim): s...

生成对抗网络GAN是去年以来比较火的一个技术,它通过一个生成网络来形成新的内容,再通过一个判别网络来判断生成的内容是否是想要的内容。

一个简单的实现如下(非原创):

    # ResNetGenerator
    import torch
    import torch.nn as nn
    
    class ResNetBlock(nn.Module):
    
        def __init__(self, dim):
            super(ResNetBlock, self).__init__()
            self.conv_block = self.build_conv_block(dim)
    
        def build_conv_block(self, dim):
            conv_block = []
    
            conv_block += [nn.ReflectionPad2d(1)]
    
            conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                           nn.InstanceNorm2d(dim),
                           nn.ReLU(True)]
    
            conv_block += [nn.ReflectionPad2d(1)]
    
            conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=0, bias=True),
                           nn.InstanceNorm2d(dim)]
    
            return nn.Sequential(*conv_block)
    
        def forward(self, x):
            out = x + self.conv_block(x)
            return out
    
    
    class ResNetGenerator(nn.Module):
    
        def __init__(self, input_nc=3, output_nc=3, ngf=64, n_blocks=9): 
    
            assert(n_blocks >= 0)
            super(ResNetGenerator, self).__init__()
    
            self.input_nc = input_nc
            self.output_nc = output_nc
            self.ngf = ngf
    
            model = [nn.ReflectionPad2d(3),
                     nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=True),
                     nn.InstanceNorm2d(ngf),
                     nn.ReLU(True)]
    
            n_downsampling = 2
            for i in range(n_downsampling):
                mult = 2**i
                model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3,
                                    stride=2, padding=1, bias=True),
                          nn.InstanceNorm2d(ngf * mult * 2),
                          nn.ReLU(True)]
    
            mult = 2**n_downsampling
            for i in range(n_blocks):
                model += [ResNetBlock(ngf * mult)]
    
            for i in range(n_downsampling):
                mult = 2**(n_downsampling - i)
                model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
                                             kernel_size=3, stride=2,
                                             padding=1, output_padding=1,
                                             bias=True),
                          nn.InstanceNorm2d(int(ngf * mult / 2)),
                          nn.ReLU(True)]
    
            model += [nn.ReflectionPad2d(3)]
            model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
            model += [nn.Tanh()]
    
            self.model = nn.Sequential(*model)
    
        def forward(self, input):
            return self.model(input)

    我们可以用它来实现把马变成斑马,首先创建一个实例

      netG = ResNetGenerator()

      然后下载一个训练好的模型参数给我们的netG

        !git clone https://github.com/deep-learning-with-pytorch/dlwpt-code.git
        model_path = 'dlwpt-code/data/p1ch2/horse2zebra_0.4.0.pth'
        model_data = torch.load(model_path
        netG.load_state_dict(model_data)

        将模型调整为评估模式

          netG.eval()

          随便找一张马的图片,读取图片

            from PIL import Image
            from torchvision import transforms
            img = Image.open("horse.jpg")
            img

            图片

            对图片进行一些处理

              preprocess = transforms.Compose([transforms.Resize(256), 
                                               transforms.ToTensor()])
              img_t = preprocess(img)
              batch_t = torch.unsqueeze(img_t, 0
              batch_out = netG(batch_t)

              指马为斑马

                out_t = (batch_out.data.squeeze() + 1.0) / 2.0
                out_img = transforms.ToPILImage()(out_t)
                out_img

                图片

                【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
                • 点赞
                • 收藏
                • 关注作者

                评论(0

                0/1000
                抱歉,系统识别当前为高风险访问,暂不支持该操作

                全部回复

                上滑加载中

                设置昵称

                在此一键设置昵称,即可参与社区互动!

                *长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

                *长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。