cartoongan 图像动漫化

举报
HWCloudAI 发表于 2022/12/09 16:08:30 2022/12/09
【摘要】 cartoongan 图像动漫化本案例是 CartoonGAN: Generative Adversarial Networks for Photo Cartoonization的论文复习案例拷贝数据之后,将你想动漫化的图像放到cartoongan-pytorch/test_img/文件夹下,运行后面代码即可可以切换不同生成风格,Hosoda/Shinkai/Paprika/Hayao参考...

cartoongan 图像动漫化

本案例是 CartoonGAN: Generative Adversarial Networks for Photo Cartoonization
的论文复习案例

拷贝数据之后,将你想动漫化的图像放到cartoongan-pytorch/test_img/文件夹下,运行后面代码即可

可以切换不同生成风格,Hosoda/Shinkai/Paprika/Hayao

参考:https://github.com/venture-anime/cartoongan-pytorch

拷贝代码和数据

import moxing as mox
mox.file.copy_parallel('obs://obs-aigallery-zc/clf/code/cartoongan-pytorch','cartoongan-pytorch')

%cd cartoongan-pytorch

运行代码

import torch
import os
import numpy as np
import torchvision.utils as vutils

from PIL import Image
import torchvision.transforms as transforms
from torch.autograd import Variable

import matplotlib.pyplot as plt

from network.Transformer import Transformer
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--input_dir", default="test_img")
parser.add_argument("--load_size", default=1280)
parser.add_argument("--model_path", default="./pretrained_model")
parser.add_argument("--style", default="Hosoda")  # 在这里切换风格, Hosoda/Shinkai/Paprika/Hayao
parser.add_argument("--output_dir", default="test_output")
parser.add_argument("--gpu", type=int, default=0)

# opt = parser.parse_args()
opt, unknown = parser.parse_known_args()
valid_ext = [".jpg", ".png", ".jpeg"]

# setup
if not os.path.exists(opt.input_dir):
    os.makedirs(opt.input_dir)
if not os.path.exists(opt.output_dir):
    os.makedirs(opt.output_dir)

# load pretrained model
model = Transformer()
model.load_state_dict(
    torch.load(os.path.join(opt.model_path, opt.style + "_net_G_float.pth"))
)
model.eval()

disable_gpu = opt.gpu == -1 or not torch.cuda.is_available()

if disable_gpu:
    print("CPU mode")
    model.float()
else:
    print("GPU mode")
    model.cuda()

for i,files in enumerate(os.listdir(opt.input_dir)):
    ext = os.path.splitext(files)[1]
    if ext not in valid_ext:
        continue
    # load image
    input_image = Image.open(os.path.join(opt.input_dir, files)).convert("RGB")
    input_image = np.asarray(input_image)
    # RGB -> BGR
    input_image = input_image[:, :, [2, 1, 0]]
    input_image = transforms.ToTensor()(input_image).unsqueeze(0)
    # preprocess, (-1, 1)
    input_image = -1 + 2 * input_image
    if disable_gpu:
        input_image = Variable(input_image).float()
    else:
        input_image = Variable(input_image).cuda()

    # forward
    output_image = model(input_image)
    output_image = output_image[0]
    # BGR -> RGB
    output_image = output_image[[2, 1, 0], :, :]
    output_image = output_image.data.cpu().float() * 0.5 + 0.5
    # save
    vutils.save_image(
        output_image,
        os.path.join(opt.output_dir, files[:-4] + "_" + opt.style + ".jpg"),
    )
    
    original = np.array(Image.open(os.path.join(opt.input_dir, files)))
    style = np.array(Image.open(os.path.join(opt.output_dir, files[:-4] + "_" + opt.style + ".jpg")))
    
    plt.figure(figsize=(20,20)) # 显示缩放比例
    plt.subplot(i+1,2,1)
    plt.imshow(original)
    plt.subplot(i+1,2,2)
    plt.imshow(style)
    plt.show()

print("Done!")

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。