天空背景转换

举报
HWCloudAI 发表于 2022/12/01 11:02:21 2022/12/01
【摘要】 这个 notebook 基于预印本论文「Castle in the Sky: Dynamic Sky Replacement and Harmonization in Videos, arXiv:2010.11800.」提供了最基本的视频天空替换的可复现例子。项目首页 | GitHub | 预印本框架使用的是:PyTorch1.4硬件用的是:GPU: 1*P100|CPU: 8核 64GB ...

这个 notebook 基于预印本论文「Castle in the Sky: Dynamic Sky Replacement and Harmonization in Videos, arXiv:2010.11800.」提供了最基本的视频天空替换的可复现例子。

项目首页 | GitHub | 预印本

框架使用的是:PyTorch1.4
硬件用的是:GPU: 1*P100|CPU: 8核 64GB

代码准备

# --- Environment setup (IPython notebook cells; `!` and `%` are IPython magics) ---
import os
# Download and unpack the SkyAR code bundle.
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/clf/code/SKYAR.zip
os.system('unzip SKYAR.zip')
# Replace plain opencv-python with the contrib build, and install notebook helpers.
!pip uninstall -y opencv-python 
!pip uninstall -y opencv-contrib-python
!pip install opencv-contrib-python
!pip install ipywidgets
!pip install aubio ffmpeg
%cd SKYAR/SkyAR
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import glob
import argparse
from networks import *
from skyboxengine import *
import utils
import torch

%matplotlib inline

# Select the compute device: first CUDA GPU if available, otherwise CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

配置模型

# Load the model configuration from JSON; utils.parse_config returns a
# namespace-like object whose fields mirror the JSON keys.
# (The original cell also built an argparse.ArgumentParser that was never
# used — args come entirely from the JSON file — so it has been removed.)
args = utils.parse_config(path_to_json='./config/config-annarbor-castle.json')
# To try another style, adjust the parameters below before building SkyFilter.

# args.net_G = "coord_resnet50"
args.ckptdir = "model"

# args.datadir = "./test_videos/annarbor.mp4" # change the foreground video
# args.skybox = "floatingcastle.jpg" # choose the skybox template

# args.in_size_w = 384 # sky-matting network input resolution
# args.in_size_h = 384 # ...
# args.out_size_w = 845 # output video resolution
# args.out_size_h = 480 # ...

# args.skybox_center_crop = 0.5 # virtual camera field of view
# args.auto_light_matching = False 
# args.relighting_factor = 0.8
# args.recoloring_factor = 0.5
# args.halo_effect = True

定义辅助功能

class SkyFilter():
    """Replace the sky region of a video with a rendered skybox.

    Per frame: predict a sky matte with ``net_G``, refine it via
    ``SkyBox.skymask_refinement``, then blend the skybox into the frame with
    ``SkyBox.skyblend``.  Results go to ``demo.avi`` (output only) and
    ``demo-cat.avi`` (input and output side by side).
    """

    def __init__(self, args):
        # args is the namespace produced by utils.parse_config.
        self.ckptdir = args.ckptdir
        self.datadir = args.datadir
        self.input_mode = args.input_mode

        self.in_size_w, self.in_size_h = args.in_size_w, args.in_size_h
        self.out_size_w, self.out_size_h = args.out_size_w, args.out_size_h

        self.skyboxengine = SkyBox(args)

        # Sky-matting network: 3-channel RGB in, 1-channel sky mask out.
        self.net_G = define_G(input_nc=3, output_nc=1, ngf=64, netG=args.net_G).to(device)
        self.load_model()

        self.video_writer = cv2.VideoWriter('demo.avi', cv2.VideoWriter_fourcc(*'MJPG'),
                                            20.0, (args.out_size_w, args.out_size_h))
        self.video_writer_cat = cv2.VideoWriter('demo-cat.avi', cv2.VideoWriter_fourcc(*'MJPG'),
                                            20.0, (2*args.out_size_w, args.out_size_h))

        # makedirs(exist_ok=True) also creates missing parents and avoids the
        # check-then-create race of os.path.exists + os.mkdir.
        os.makedirs(args.output_dir, exist_ok=True)

        # Buffer of concatenated result frames, kept for later inspection.
        self.output_img_list = []

        self.save_jpgs = args.save_jpgs


    def load_model(self):
        """Load the pretrained sky-matting checkpoint and switch to eval mode."""
        print('loading the best checkpoint...')
        # map_location lets a CUDA-trained checkpoint load on CPU-only hosts;
        # without it torch.load raises when CUDA is unavailable.
        checkpoint = torch.load(os.path.join(self.ckptdir, 'best_ckpt.pt'),
                                map_location=device)
        self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
        self.net_G.to(device)
        self.net_G.eval()


    def write_video(self, img_HD, syneth):
        """Append one synthesized frame (and the side-by-side frame) to the writers.

        Both inputs are float RGB images in [0, 1]; they are scaled to uint8
        and channel-flipped to BGR (``[:, :, ::-1]``) for OpenCV.
        """
        frame = np.array(255.0 * syneth[:, :, ::-1], dtype=np.uint8)
        self.video_writer.write(frame)

        frame_cat = np.concatenate([img_HD, syneth], axis=1)
        frame_cat = np.array(255.0 * frame_cat[:, :, ::-1], dtype=np.uint8)
        self.video_writer_cat.write(frame_cat)

        # Keep the concatenated frame in the result buffer.
        self.output_img_list.append(frame_cat)


    def synthesize(self, img_HD, img_HD_prev):
        """Predict the sky mask for one frame and blend in the skybox.

        Returns ``(syneth, G_pred, skymask)``, all at ``img_HD`` resolution.
        """
        h, w, c = img_HD.shape

        # Downscale to the network input resolution.
        img = cv2.resize(img_HD, (self.in_size_w, self.in_size_h))

        img = np.array(img, dtype=np.float32)
        img = torch.tensor(img).permute([2, 0, 1]).unsqueeze(0)  # HWC -> NCHW

        with torch.no_grad():
            G_pred = self.net_G(img.to(device))
            # Upscale the predicted mask back to frame resolution.
            G_pred = torch.nn.functional.interpolate(G_pred, (h, w), mode='bicubic', align_corners=False)
            G_pred = G_pred[0, :].permute([1, 2, 0])
            G_pred = torch.cat([G_pred, G_pred, G_pred], dim=-1)  # replicate to 3 channels
            G_pred = np.array(G_pred.detach().cpu())
            # Bicubic interpolation can overshoot [0, 1]; clip it back.
            G_pred = np.clip(G_pred, a_max=1.0, a_min=0.0)

        skymask = self.skyboxengine.skymask_refinement(G_pred, img_HD)

        syneth = self.skyboxengine.skyblend(img_HD, img_HD_prev, skymask)

        return syneth, G_pred, skymask


    def cvtcolor_and_resize(self, img_HD):
        """Convert a BGR uint8 frame to float RGB in [0, 1] at output size."""
        img_HD = cv2.cvtColor(img_HD, cv2.COLOR_BGR2RGB)
        img_HD = np.array(img_HD / 255., dtype=np.float32)
        img_HD = cv2.resize(img_HD, (self.out_size_w, self.out_size_h))

        return img_HD


    def process_video(self):
        """Run the sky filter over every frame of ``self.datadir``."""
        cap = cv2.VideoCapture(self.datadir)
        m_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        img_HD_prev = None

        try:
            for idx in range(m_frames):
                ret, frame = cap.read()
                if not ret:  # ran out of frames early
                    break

                img_HD = self.cvtcolor_and_resize(frame)

                # The first frame has no predecessor; reuse itself.
                if img_HD_prev is None:
                    img_HD_prev = img_HD

                syneth, G_pred, skymask = self.synthesize(img_HD, img_HD_prev)

                self.write_video(img_HD, syneth)

                img_HD_prev = img_HD

                if idx % 50 == 1:
                    print('processing video, frame %d / %d ... ' % (idx, m_frames))
        finally:
            # Finalize the AVI files so they are playable (the later ffmpeg
            # conversion step needs them closed) and free the capture device.
            cap.release()
            self.video_writer.release()
            self.video_writer_cat.release()

Now you can process your video

# Build the sky filter from the parsed configuration and run it
# over every frame of the configured input video.
sf = SkyFilter(args)
sf.process_video()

查看结果

# Convert the AVI output to MP4 (H.264) so browsers can play it inline.
!ffmpeg -i demo-cat.avi -vcodec libx264 -acodec aac demo-cat.mp4
from IPython.display import HTML
from base64 import b64encode
# Embed the MP4 as a base64 data URL inside an HTML5 <video> player.
mp4 = open('demo-cat.mp4','rb').read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=600 controls>
      <source src="%s" type="video/mp4">
</video>
""" % data_url)

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0)

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。