Colorizing Black-and-White Images with ModelArts
Instance-aware Image Colorization
Jheng-Wei Su, Hung-Kuo Chu, and Jia-Bin Huang
In IEEE Conference on Computer Vision and Pattern Recognition (CVPR), 2020.
Image colorization is inherently an ill-posed problem with multi-modal uncertainty. Previous methods leverage deep neural networks to map an input grayscale image directly to a plausible color output. Although these learning-based methods show impressive performance, they usually fail on input images that contain multiple objects. The main reason is that existing models learn and colorize over the entire image; without a clear figure-ground separation, they cannot effectively locate and learn meaningful object-level semantics. The paper proposes a method that achieves instance-aware colorization. The network architecture leverages an off-the-shelf object detector to obtain cropped object images, and uses an instance colorization network to extract object-level features. A similar network extracts full-image features, and a fusion module is applied across the object-level and image-level features to predict the final colors. Both the colorization networks and the fusion module are learned from large-scale datasets. Experimental results show that this work outperforms existing methods on different quality metrics and achieves state-of-the-art performance on image colorization.
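As a rough illustration of the fusion idea (a minimal sketch, not the authors' actual module, which learns the weight maps end to end and fuses at multiple feature scales): each instance's features are pasted back into the full-image feature map at its detected box, and the branches are blended with per-pixel softmax weights.
import torch
import torch.nn.functional as F

def fuse_features(full_feat, inst_feats, inst_weights, full_weight, boxes):
    """Blend a full-image feature map with per-instance feature maps.
    full_feat:    (C, H, W) full-image branch features
    inst_feats:   list of (C, h, w) instance branch features
    inst_weights: list of (1, h, w) per-pixel weight maps, one per instance
    full_weight:  (1, H, W) per-pixel weight map for the full-image branch
    boxes:        list of (x1, y1, x2, y2) boxes in feature-map coordinates
    """
    C, H, W = full_feat.shape
    feat_maps, weight_maps = [full_feat], [full_weight]
    for feat, wmap, (x1, y1, x2, y2) in zip(inst_feats, inst_weights, boxes):
        h, w = y2 - y1, x2 - x1
        # Resize the instance crop back to its box size in the full map.
        f = F.interpolate(feat[None], size=(h, w), mode='bilinear', align_corners=False)[0]
        m = F.interpolate(wmap[None], size=(h, w), mode='bilinear', align_corners=False)[0]
        padded_f = torch.zeros(C, H, W)
        padded_m = torch.full((1, H, W), float('-inf'))  # -inf -> zero weight after softmax
        padded_f[:, y1:y2, x1:x2] = f
        padded_m[:, y1:y2, x1:x2] = m
        feat_maps.append(padded_f)
        weight_maps.append(padded_m)
    # Per-pixel softmax across branches, then a weighted sum.
    weights = torch.softmax(torch.stack(weight_maps), dim=0)  # (N, 1, H, W)
    return (torch.stack(feat_maps) * weights).sum(dim=0)      # (C, H, W)

# Toy usage: one 64x64 feature map with a single instance in a 16x16 box.
full = torch.randn(8, 64, 64)
fused = fuse_features(full, [torch.randn(8, 32, 32)], [torch.randn(1, 32, 32)],
                      torch.zeros(1, 64, 64), [(10, 10, 26, 26)])
print(fused.shape)  # torch.Size([8, 64, 64])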
Download the code and data
import os
!wget https://obs-aigallery-zc.obs.cn-north-4.myhuaweicloud.com/clf/code/InstColorization.zip
os.system('unzip InstColorization.zip')
Install dependencies
import torch, torchvision
print(torch.__version__, torch.cuda.is_available())
!gcc --version
!pip install torch==1.5 torchvision==0.6
!pip install cython pyyaml==5.1
!pip install -U 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'
!pip install dominate==2.4.0
!pip install detectron2==0.1.3 -f https://dl.fbaipublicfiles.com/detectron2/wheels/cu101/index.html
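Optionally, verify the installed versions before continuing; the detectron2 cu101 wheel above must match the installed PyTorch build.
# Optional sanity check: detectron2 0.1.3 (cu101) expects torch 1.5 + CUDA 10.1.
import torch, torchvision, detectron2
print(torch.__version__, torch.version.cuda)   # expect 1.5.x and 10.1
print(torchvision.__version__)                 # expect 0.6.x
print(detectron2.__version__)                  # expect 0.1.3
assert torch.cuda.is_available(), 'a CUDA GPU is required for this notebook'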
Prepare for colorization
cd InstColorization/
/home/ma-user/work/InstColorization/InstColorization
Configure Detectron2
This may take a while.
from os.path import join, isfile, isdir
from os import listdir
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from argparse import ArgumentParser
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()
import numpy as np
import cv2
# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
import torch
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml"))
cfg.MODEL.ROI_HEADS.SCORE_THRESH_TEST = 0.7
cfg.MODEL.WEIGHTS = model_zoo.get_checkpoint_url("COCO-InstanceSegmentation/mask_rcnn_X_101_32x8d_FPN_3x.yaml")
predictor = DefaultPredictor(cfg)
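As an optional smoke test (not part of the original notebook), the predictor can be run on a blank image to confirm the weights downloaded and the model runs end to end.
# Optional smoke test: DefaultPredictor takes a BGR uint8 image and returns
# a dict with an "instances" field; a blank image should yield no detections.
dummy = np.zeros((256, 256, 3), dtype=np.uint8)
print(len(predictor(dummy)["instances"]))  # expect 0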
Let’s create a bounding box folder to save our prediction results.
input_dir = "example"
image_list = [f for f in listdir(input_dir) if isfile(join(input_dir, f))]
output_npz_dir = "{0}_bbox".format(input_dir)
if os.path.isdir(output_npz_dir) is False:
    print('Create path: {0}'.format(output_npz_dir))
    os.makedirs(output_npz_dir)
Here we take only the L channel as input, so that the box predictions stay consistent even when the original image is a color image.
for image_path in image_list:
    img = cv2.imread(join(input_dir, image_path))
    lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
    l_channel, a_channel, b_channel = cv2.split(lab_image)
    l_stack = np.stack([l_channel, l_channel, l_channel], axis=2)
    outputs = predictor(l_stack)
    save_path = join(output_npz_dir, image_path.split('.')[0])
    pred_bbox = outputs["instances"].pred_boxes.to(torch.device('cpu')).tensor.numpy()
    pred_scores = outputs["instances"].scores.cpu().data.numpy()
    np.savez(save_path, bbox=pred_bbox, scores=pred_scores)
Now we have all the images’ prediction results.
!ls example_bbox
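Optionally, one of the saved .npz files can be loaded back to inspect its contents (np.savez appends the .npz extension to the save path used above).
# Optional: inspect one prediction file; the keys match the np.savez call above.
import numpy as np
sample = np.load(join(output_npz_dir, image_list[0].split('.')[0] + '.npz'))
print(sample['bbox'].shape)  # (num_instances, 4) boxes as (x1, y1, x2, y2)
print(sample['scores'])      # detection confidences, all above the 0.7 threshold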
Image colorization
We first set up the libraries and options.
import sys
import time
from options.train_options import TestOptions
from models import create_model
import torch
from tqdm import tqdm_notebook
from fusion_dataset import Fusion_Testing_Dataset
from util import util
import multiprocessing
multiprocessing.set_start_method('spawn', True)
torch.backends.cudnn.benchmark = True
sys.argv = [sys.argv[0]]
opt = TestOptions().parse()
Then we need to create a results folder to save our predicted color images and read the dataset loader.
save_img_path = opt.results_img_dir
if os.path.isdir(save_img_path) is False:
    print('Create path: {0}'.format(save_img_path))
    os.makedirs(save_img_path)
opt.batch_size = 1
dataset = Fusion_Testing_Dataset(opt, -1)
dataset_loader = torch.utils.data.DataLoader(dataset, batch_size=opt.batch_size)
dataset_size = len(dataset)
print('#Testing images = %d' % dataset_size)
#Testing images = 8
Load the pre-trained model.
model = create_model(opt)
model.setup_to_test('coco_finetuned_mask_256_ffs')
Colorize every image in dataset_loader.
count_empty = 0
for data_raw in tqdm_notebook(dataset_loader):
    data_raw['full_img'][0] = data_raw['full_img'][0].cuda()
    if data_raw['empty_box'][0] == 0:
        data_raw['cropped_img'][0] = data_raw['cropped_img'][0].cuda()
        box_info = data_raw['box_info'][0]
        box_info_2x = data_raw['box_info_2x'][0]
        box_info_4x = data_raw['box_info_4x'][0]
        box_info_8x = data_raw['box_info_8x'][0]
        cropped_data = util.get_colorization_data(data_raw['cropped_img'], opt, ab_thresh=0, p=opt.sample_p)
        full_img_data = util.get_colorization_data(data_raw['full_img'], opt, ab_thresh=0, p=opt.sample_p)
        model.set_input(cropped_data)
        model.set_fusion_input(full_img_data, [box_info, box_info_2x, box_info_4x, box_info_8x])
        model.forward()
    else:
        count_empty += 1
        full_img_data = util.get_colorization_data(data_raw['full_img'], opt, ab_thresh=0, p=opt.sample_p)
        model.set_forward_without_box(full_img_data)
    model.save_current_imgs(join(save_img_path, data_raw['file_id'][0] + '.png'))
print('{0} images without bounding boxes'.format(count_empty))
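The outputs can be listed to confirm that every test image was written (this notebook's opt.results_img_dir points at the results folder, which the display code below also assumes).
!ls results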
Display the colorized results
The model predicts color at a reduced resolution, so the code below resizes the predicted ab channels back to the original resolution and recombines them with the original L channel. Change show_index to use a different image.
def imshow(img):
    import IPython
    import cv2
    _, ret = cv2.imencode('.jpg', img)
    i = IPython.display.Image(data=ret)
    IPython.display.display(i)
img_name_list = ['000000022969', '000000023781', '000000046872', '000000050145']
# Change the index (0-3) to use a different image
show_index = 1
img = cv2.imread('example/'+img_name_list[show_index]+'.jpg')
lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
l_channel, _, _ = cv2.split(lab_image)
img = cv2.imread('results/'+img_name_list[show_index]+'.png')
lab_image = cv2.cvtColor(img, cv2.COLOR_BGR2LAB)
_, a_pred, b_pred = cv2.split(lab_image)
a_pred = cv2.resize(a_pred, (l_channel.shape[1], l_channel.shape[0]))
b_pred = cv2.resize(b_pred, (l_channel.shape[1], l_channel.shape[0]))
gray_color = np.ones_like(a_pred) * 128
gray_image = cv2.cvtColor(np.stack([l_channel, gray_color, gray_color], 2), cv2.COLOR_LAB2BGR)
color_image = cv2.cvtColor(np.stack([l_channel, a_pred, b_pred], 2), cv2.COLOR_LAB2BGR)
# save_img_path = 'results_origin/'
# if os.path.isdir(save_img_path) is False:
# print('Create path: {0}'.format(save_img_path))
# os.makedirs(save_img_path)
# cv2.imwrite('results_origin/'+img_name_list[show_index]+'.png', color_image)
imshow(np.concatenate([gray_image, color_image], 1))
Training
Requirements
- CUDA 10.1
- Python3
- Pytorch >= 1.5
- Detectron2
- OpenCV-Python
- Pillow/scikit-image
- Please refer to env.yml for detailed dependencies.
Setup
- Get the code:
cd InstColorization
- Install all dependencies:
conda env create --file env.yml
- Activate the conda environment:
conda activate instacolorization
- Install the remaining dependencies:
sh scripts/install.sh
Dataset preparation
COCOStuff
- Download and unzip the COCOStuff training set:
sh scripts/prepare_cocostuff.sh
- The training set is now in train_data.
Your own data
- If you use your own data, change the dataset path on L1 of scripts/prepare_train_box.sh and L1 of scripts/train.sh.
Pretrained models
- Download the pretrained models from Google Drive.
They have already been downloaded in this notebook:
sh scripts/download_model.sh
- The pretrained models are placed in checkpoints.
Instance prediction
Use the following command to predict bounding boxes for all the images in the ${DATASET_DIR} folder:
sh scripts/prepare_train_box.sh
All predictions will be saved in the ${DATASET_DIR}_bbox folder.
Train the colorization model
Run the following command to start training:
sh scripts/train.sh
Run visdom -port 8098 and go to http://localhost:8098 to view the training results and losses.
This is a 3-stage training process (a sketch of the stage-2 warm start follows this list).
- We first train the full-image colorization branch, starting from the siggraph_retrained pretrained weights.
- We then use the full-image branch's weights as the pretrained weights for the instance colorization branch.
- Finally, we train the fusion module.
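As a rough, self-contained sketch of the stage-2 warm start described above (SimpleBranch is an illustrative stand-in, not the repository's actual network): because the two branches share the same architecture, the full-image branch's trained weights can be copied directly via state_dict.
import torch
import torch.nn as nn

class SimpleBranch(nn.Module):
    # Stand-in for the shared branch architecture: L channel in, ab out.
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1, 2, kernel_size=3, padding=1)
    def forward(self, x):
        return self.conv(x)

full_branch = SimpleBranch()      # plays the trained full-image branch
instance_branch = SimpleBranch()  # plays the untrained instance branch
instance_branch.load_state_dict(full_branch.state_dict())  # warm start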
Test the colorization model
- Our model's weights are placed in checkpoints/coco_mask.
- Change the checkpoint path in test_fusion.py's L38 from coco_finetuned_mask_256_ffs to coco_mask.
- Use the command below to colorize all the images in the example folder based on the weights placed in coco_mask:
python test_fusion.py --name test_fusion --sample_p 1.0 --model fusion --fineSize 256 --test_img_dir example --results_img_dir results
- All the colorized results will be saved in the results folder.