基于ModelArts实现小样本学习
【摘要】 小样本学习:本 baseline 采用 PyTorch 框架,应用 ModelArts 的 Notebook 进行开发,为论文《Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation》(Hung-Yu Tseng, Hsin-Ying Lee, Jia-Bin Huang, Ming-Hsuan Yang, ICLR 2020)的复现代码。
小样本学习
本baseline采用pytorch框架,应用ModelArts的Notebook进行开发
为该论文复现代码
Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation
Hung-Yu Tseng, Hsin-Ying Lee, Jia-Bin Huang, Ming-Hsuan Yang
International Conference on Learning Representations (ICLR), 2020 (spotlight)
训练
使用了CUB_200_2011,miniImageNet,omniglot数据集
要使用其他数据集,请参考下方 GitHub 仓库中的数据准备说明
训练步骤:
- 1.训练预训练特征编码器(PRETRAIN设置成baseline++ 或 baseline)
python train_baseline.py --method PRETRAIN --dataset miniImagenet --name PRETRAIN --train_aug
- 2.训练
可以用train.py或者train_baseline.py
参考github: https://github.com/hytseng0509/CrossDomainFewShot
获取代码和数据
需要下载一段时间
# Fetch the baseline code and data from OBS (this copy can take a while).
import moxing as mox
mox.file.copy_parallel('obs://ma-competitions-bj4/fewshot/baseline','baseline')
# IPython magic (notebook-only): change working directory into the project.
%cd baseline
# (notebook output of the %cd command above:)
/home/ma-user/work/baseline
导入依赖库
!pip install torch==1.5 torchvision==0.6
!pip install -r requirements.txt
import numpy as np
import random
import torch
from data.datamgr import SetDataManager, SimpleDataManager
from options import parse_args, get_resume_file, load_warmup_state
from methods.LFTNet import LFTNet
from methods.backbone import model_dict
from train import cycle,train
from PIL import Image
import torchvision.transforms as transforms
import moxing as mox
import argparse
import os
参数设置
# ---- experiment configuration -------------------------------------------
# Parsed with parse_known_args so stray notebook/kernel argv entries are
# collected into `unknown` instead of raising.
script = 'train'


def _build_parser(mode):
    """Construct the few-shot argument parser for the given mode ('train'/'test')."""
    p = argparse.ArgumentParser(description='few-shot script %s' % mode)
    # dataset / model selection
    p.add_argument('--dataset', default='multi',
                   help='miniImagenet/cub/omniglot, specify multi for training with multiple domains')
    p.add_argument('--testset', default='cub',
                   help='miniImagenet/cub/omniglot, valid only when dataset=multi')
    p.add_argument('--model', default='ResNet10',
                   help='model: Conv{4|6} / ResNet{10|18|34}')  # we use ResNet10
    p.add_argument('--method', default='baseline',
                   help='baseline/baseline++/protonet/matchingnet/relationnet{_softmax}/gnnnet')
    # episode shape (N-way K-shot)
    p.add_argument('--train_n_way', default=5, type=int,
                   help='class num to classify for training')
    p.add_argument('--test_n_way', default=5, type=int,
                   help='class num to classify for testing (validation) ')
    p.add_argument('--n_shot', default=5, type=int,
                   help='number of labeled data in each class, same as n_support')
    p.add_argument('--train_aug', action='store_true',
                   help='perform data augmentation or not during training ')
    # bookkeeping paths
    p.add_argument('--name', default='tmp', type=str, help='')
    p.add_argument('--save_dir', default='./output', type=str, help='')
    p.add_argument('--data_dir', default='./filelists', type=str, help='')
    if mode == 'train':
        p.add_argument('--num_classes', default=200, type=int,
                       help='total number of classes in softmax, only used in baseline')
        p.add_argument('--save_freq', default=25, type=int, help='Save frequency')
        p.add_argument('--start_epoch', default=0, type=int, help='Starting epoch')
        p.add_argument('--stop_epoch', default=400, type=int, help='Stopping epoch')
        p.add_argument('--resume', default='', type=str,
                       help='continue from previous trained model with largest epoch')
        p.add_argument('--resume_epoch', default=50, type=int, help='')
        p.add_argument('--warmup', default='gg3b0', type=str,
                       help='continue from baseline, neglected if resume is true')
    elif mode == 'test':
        p.add_argument('--split', default='novel', help='base/val/novel')
        p.add_argument('--save_epoch', default=10, type=int,
                       help='load the model trained in x epoch, use the best model if x is -1')
    return p


parser = _build_parser(script)
params, unknown = parser.parse_known_args()

# Fixed settings for this notebook run (override the CLI defaults).
params.method = 'matchingnet'
params.name = 'multi_TESTSET_lft_METHOD'
params.warmup = 'baseline++'
params.resume = 'multi_TESTSET_lft_METHOD'
模型训练
由于数据量过大,要获取完整数据,请进入baseline/filelists/,运行data_download.ipynb
如果使用其他数据集请修改datasets
训练前,将下列代码取消注释
# NOTE(review): the training driver below is intentionally shipped commented
# out — uncomment it (after downloading the full data, see
# baseline/filelists/data_download.ipynb) to actually run LFTNet training.
# It uses the `params` object and the data/model imports defined earlier
# in this notebook.
# # --- main function ---
# if __name__=='__main__':
# # set numpy random seed
# np.random.seed(10)
# # parse argument
# print('--- LFTNet training: {} ---\n'.format(params.name))
# print(params)
# # output and tensorboard dir
# params.tf_dir = '%s/log/%s'%(params.save_dir, params.name)
# params.checkpoint_dir = '%s/checkpoints/%s'%(params.save_dir, params.name)
# if not os.path.isdir(params.checkpoint_dir):
# os.makedirs(params.checkpoint_dir)
# # dataloader
# print('\n--- prepare dataloader ---')
# print('train with multiple seen domains (unseen domain: {})'.format(params.testset))
# datasets = ['miniImagenet', 'cub', 'omniglot'] #
# datasets.remove(params.testset)
# val_file = os.path.join(params.data_dir, 'miniImagenet', 'val.json')
# # model
# print('\n--- build LFTNet model ---')
# if 'Conv' in params.model:
# image_size = 84
# else:
# image_size = 224
# n_query = max(1, int(16* params.test_n_way/params.train_n_way))
# train_few_shot_params = dict(n_way = params.train_n_way, n_support = params.n_shot)
# base_datamgr = SetDataManager(image_size, n_query = n_query, **train_few_shot_params)
# aux_datamgr = SimpleDataManager(image_size, batch_size=16)
# aux_iter = iter(cycle(aux_datamgr.get_data_loader(os.path.join(params.data_dir, 'miniImagenet', 'base.json'), aug=params.train_aug)))
# test_few_shot_params = dict(n_way = params.test_n_way, n_support = params.n_shot)
# val_datamgr = SetDataManager(image_size, n_query = n_query, **test_few_shot_params)
# val_loader = val_datamgr.get_data_loader( val_file, aug = False)
# model = LFTNet(params, tf_path=params.tf_dir)
# model.cuda()
# # resume training
# start_epoch = params.start_epoch
# stop_epoch = params.stop_epoch
# if params.resume != '':
# print(params.save_dir, params.resume, params.resume_epoch)
# resume_file = get_resume_file('%s/checkpoints/%s'%(params.save_dir, params.resume), params.resume_epoch)
# print(resume_file)
# if resume_file is not None:
# start_epoch = model.resume(resume_file)
# print(' resume the training with at {} epoch (model file {})'.format(start_epoch, params.resume))
# else:
# raise ValueError('No resume file')
# # load pre-trained feature encoder
# else:
# if params.warmup == 'gg3b0':
# raise Exception('Must provide pre-trained feature-encoder file using --warmup option!')
# model.model.feature.load_state_dict(load_warmup_state('%s/checkpoints/%s'%(params.save_dir, params.warmup), params.method), strict=False)
# # training
# print('\n--- start the training ---')
# train(base_datamgr, datasets, aux_iter, val_loader, model, start_epoch, stop_epoch, params)
# Export the best checkpoint produced by training into ./model so the
# inference cell and the ModelArts model-import cell below can pick it up.
model_path = 'model/best_model.pth'
mox.file.copy_parallel('./output/checkpoints/'+params.name+'/best_model.pth',model_path)
模型测试
feature即模型从图像中提取的特征向量
# --- single-image feature extraction with the trained encoder ---
img_file = './filelists/test.jpg'

# Validation-style preprocessing: resize to ~1.15x the crop size, center-crop
# to 224, then ImageNet mean/std normalization.
infer_transformation = transforms.Compose([
    transforms.Resize([int(224 * 1.15), int(224 * 1.15)]),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

model = model_dict['ResNet10']()
tmp = torch.load(model_path, map_location='cpu')
# Checkpoints store weights either under 'state' (LFTNet) or 'model_state'
# (baseline trainer) — try the former, fall back to the latter.
state = tmp['state'] if 'state' in tmp else tmp['model_state']

# Keep only the feature-encoder weights, stripping the 'feature.' prefix so
# they match the bare backbone; drop everything else, including the
# feature-wise transformation parameters (gamma/beta).
for key in list(state.keys()):
    if 'feature.' in key and 'gamma' not in key and 'beta' not in key:
        state[key.replace('feature.', '')] = state.pop(key)
    else:
        state.pop(key)
model.load_state_dict(state)
model.eval()

img = Image.open(img_file).convert('RGB')
img_tensor = torch.unsqueeze(infer_transformation(img), 0)
# Inference only: no_grad avoids building the autograd graph.
with torch.no_grad():
    feature = model(img_tensor)[0].numpy().tolist()
print(len(feature))
512
模型导入ModelArts
from modelarts.session import Session
from modelarts.model import Model
from modelarts.config.model_config import TransformerConfig,Params
!pip install json5
import moxing as mox
import json5
import re
import traceback
import random
# --- import the trained model into ModelArts ---
try:
    session = Session()
    config_path = 'model/config.json'
    # The config file is mandatory for model import; without it the import
    # cannot proceed.
    if mox.file.exists(config_path):
        model_location = './model'
        model_name = "Cross-Few-Shot"
        load_dict = json5.loads(mox.file.read(config_path))
        model_type = load_dict['model_type']
        # Append a random suffix to avoid name collisions with earlier imports.
        re_name = '_' + str(random.randint(0, 1000))
        model_name += re_name
        print("正在导入模型,模型名称:", model_name)
        model_instance = Model(
            session,
            model_name=model_name,            # model name
            model_version="1.0.0",            # model version
            source_location_type='LOCAL_SOURCE',
            source_location=model_location,   # path to the model files
            model_type=model_type,            # model type from config.json
        )
        print("所有模型导入完成")
    else:
        # Previously this case fell through silently; report it explicitly.
        print("未找到配置文件,无法导入模型:", config_path)
except Exception:
    print("发生了一些问题,请看下面的报错信息:")
    traceback.print_exc()
    print("模型导入失败")
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)