Sky Background Replacement
This notebook provides a minimal, reproducible example of video sky replacement based on the preprint "Castle in the Sky: Dynamic Sky Replacement and Harmonization in Videos" (arXiv:2010.11800).
Framework: PyTorch 1.4
Hardware: GPU: 1 × P100 | CPU: 8 cores, 64 GB RAM
Code Preparation
import os

# Download and unpack the SkyAR project code
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/clf/code/SKYAR.zip
os.system('unzip SKYAR.zip')

# Reinstall OpenCV so that only opencv-contrib-python is present
!pip uninstall -y opencv-python
!pip uninstall -y opencv-contrib-python
!pip install opencv-contrib-python

# Remaining dependencies
!pip install ipywidgets
!pip install aubio ffmpeg
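As an optional sanity check (this cell is an addition, not part of the original notebook), you can confirm that OpenCV and PyTorch import correctly and whether a GPU is visible before continuing:

import cv2
import torch

print('OpenCV:', cv2.__version__)
print('PyTorch:', torch.__version__)
print('CUDA available:', torch.cuda.is_available())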
%cd SKYAR/SkyAR
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os
import glob
from networks import *
from skyboxengine import *
import utils
import torch
%matplotlib inline
# Select the device to run on (GPU if available)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
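To confirm what the notebook will actually run on, you can print the selected device; on the P100 setup mentioned above this should report a CUDA device. This small check is an addition for convenience:

print('Running on:', device)
if torch.cuda.is_available():
    print('GPU:', torch.cuda.get_device_name(0))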
Model Configuration
# Load the default configuration (Ann Arbor foreground video + floating-castle skybox)
args = utils.parse_config(path_to_json='./config/config-annarbor-castle.json')
# To try a different style, modify the parameters below
# args.net_G = "coord_resnet50"
args.ckptdir = "model"
# args.datadir = "./test_videos/annarbor.mp4"  # foreground video
# args.skybox = "floatingcastle.jpg"  # skybox template
# args.in_size_w = 384  # input resolution of the sky matting network
# args.in_size_h = 384  # ...
# args.out_size_w = 845  # output video resolution
# args.out_size_h = 480  # ...
# args.skybox_center_crop = 0.5  # virtual camera field of view
# args.auto_light_matching = False
# args.relighting_factor = 0.8
# args.recoloring_factor = 0.5
# args.halo_effect = True
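If you are unsure which values are valid, you can list the bundled configuration presets and the available skybox templates. Note the assumption here: the skybox images are expected under a ./skybox folder of the unpacked SKYAR project, which is inferred from the default file name above rather than defined in this notebook.

import glob

# Bundled configuration presets
print(sorted(glob.glob('./config/*.json')))

# Skybox templates (assumed location: ./skybox)
print(sorted(glob.glob('./skybox/*.jpg')))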
Define the Helper Class
class SkyFilter():

    def __init__(self, args):
        self.ckptdir = args.ckptdir
        self.datadir = args.datadir
        self.input_mode = args.input_mode

        self.in_size_w, self.in_size_h = args.in_size_w, args.in_size_h
        self.out_size_w, self.out_size_h = args.out_size_w, args.out_size_h

        self.skyboxengine = SkyBox(args)

        self.net_G = define_G(input_nc=3, output_nc=1, ngf=64, netG=args.net_G).to(device)
        self.load_model()

        self.video_writer = cv2.VideoWriter('demo.avi', cv2.VideoWriter_fourcc(*'MJPG'),
                                            20.0, (args.out_size_w, args.out_size_h))
        self.video_writer_cat = cv2.VideoWriter('demo-cat.avi', cv2.VideoWriter_fourcc(*'MJPG'),
                                                20.0, (2 * args.out_size_w, args.out_size_h))

        if os.path.exists(args.output_dir) is False:
            os.mkdir(args.output_dir)

        self.output_img_list = []
        self.save_jpgs = args.save_jpgs

    def load_model(self):
        # Load the pre-trained sky matting checkpoint
        print('loading the best checkpoint...')
        checkpoint = torch.load(os.path.join(self.ckptdir, 'best_ckpt.pt'))
        self.net_G.load_state_dict(checkpoint['model_G_state_dict'])
        self.net_G.to(device)
        self.net_G.eval()

    def write_video(self, img_HD, syneth):
        # Convert RGB float frames in [0, 1] to BGR uint8 and write them out
        frame = np.array(255.0 * syneth[:, :, ::-1], dtype=np.uint8)
        self.video_writer.write(frame)

        frame_cat = np.concatenate([img_HD, syneth], axis=1)
        frame_cat = np.array(255.0 * frame_cat[:, :, ::-1], dtype=np.uint8)
        self.video_writer_cat.write(frame_cat)

        # Buffer the side-by-side frame so the result can be previewed later
        self.output_img_list.append(frame_cat)

    def synthesize(self, img_HD, img_HD_prev):
        h, w, c = img_HD.shape

        # Predict the sky matte at the network input resolution
        img = cv2.resize(img_HD, (self.in_size_w, self.in_size_h))
        img = np.array(img, dtype=np.float32)
        img = torch.tensor(img).permute([2, 0, 1]).unsqueeze(0)

        with torch.no_grad():
            G_pred = self.net_G(img.to(device))
            G_pred = torch.nn.functional.interpolate(G_pred, (h, w), mode='bicubic', align_corners=False)
            G_pred = G_pred[0, :].permute([1, 2, 0])
            G_pred = torch.cat([G_pred, G_pred, G_pred], dim=-1)
            G_pred = np.array(G_pred.detach().cpu())
            G_pred = np.clip(G_pred, a_max=1.0, a_min=0.0)

        # Refine the sky mask and blend the new skybox into the frame
        skymask = self.skyboxengine.skymask_refinement(G_pred, img_HD)
        syneth = self.skyboxengine.skyblend(img_HD, img_HD_prev, skymask)

        return syneth, G_pred, skymask

    def cvtcolor_and_resize(self, img_HD):
        # BGR uint8 -> RGB float in [0, 1], resized to the output resolution
        img_HD = cv2.cvtColor(img_HD, cv2.COLOR_BGR2RGB)
        img_HD = np.array(img_HD / 255., dtype=np.float32)
        img_HD = cv2.resize(img_HD, (self.out_size_w, self.out_size_h))
        return img_HD

    def process_video(self):
        # Process the video frame by frame
        cap = cv2.VideoCapture(self.datadir)
        m_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        img_HD_prev = None

        for idx in range(m_frames):
            ret, frame = cap.read()
            if ret:
                img_HD = self.cvtcolor_and_resize(frame)
                if img_HD_prev is None:
                    img_HD_prev = img_HD

                syneth, G_pred, skymask = self.synthesize(img_HD, img_HD_prev)
                self.write_video(img_HD, syneth)

                img_HD_prev = img_HD
                if idx % 50 == 1:
                    print('processing video, frame %d / %d ... ' % (idx, m_frames))
            else:
                # No more frames to read
                break

        # Release the capture and the writers so the output files are finalized
        cap.release()
        self.video_writer.release()
        self.video_writer_cat.release()
Process the Video
sf = SkyFilter(args)
sf.process_video()
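As a quick inline check (this preview cell is an addition, not part of the original pipeline), the last buffered side-by-side frame can be displayed with matplotlib. output_img_list stores BGR uint8 frames, so the channel order is flipped before plotting:

last = sf.output_img_list[-1]
plt.figure(figsize=(12, 4))
plt.imshow(last[:, :, ::-1])  # BGR -> RGB for matplotlib
plt.axis('off')
plt.show()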
View the Result
# Convert the video to MP4 (H.264) for inline playback
!ffmpeg -i demo-cat.avi -vcodec libx264 -acodec aac demo-cat.mp4
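If you also want to play back or download the output-only video, the same conversion pattern can be applied to demo.avi; this extra command is an addition to the original steps:

!ffmpeg -i demo.avi -vcodec libx264 -acodec aac demo.mp4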
# Embed the converted video in the notebook as a base64 data URL
from IPython.display import HTML
from base64 import b64encode

with open('demo-cat.mp4', 'rb') as f:
    mp4 = f.read()
data_url = "data:video/mp4;base64," + b64encode(mp4).decode()
HTML("""
<video width=600 controls>
<source src="%s" type="video/mp4">
</video>
""" % data_url)