基于ModelArts实现小样本学习

举报
HWCloudAI 发表于 2022/12/26 10:40:49 2022/12/26
【摘要】 小样本学习 本baseline采用pytorch框架,应用ModelArts的Notebook进行开发 为该论文复现代码Cross-Domain Few-Shot Classification via Learned Feature-Wise TransformationHung-Yu Tseng, Hsin-Ying Lee, Jia-Bin Huang, Ming-Hsuan Yang...

小样本学习

本baseline采用pytorch框架,应用ModelArts的Notebook进行开发

为该论文复现代码

Cross-Domain Few-Shot Classification via Learned Feature-Wise Transformation

Hung-Yu Tseng, Hsin-Ying Lee, Jia-Bin Huang, Ming-Hsuan Yang

International Conference on Learning Representations (ICLR), 2020 (spotlight)

训练

使用了CUB_200_2011,miniImageNet,omniglot数据集
要使用其他数据集请参考这里

使用其他网络:

  • 1.训练预训练特征编码器(PRETRAIN设置成baseline++ 或 baseline)
python train_baseline.py --method PRETRAIN --dataset miniImagenet --name PRETRAIN --train_aug
  • 2.训练
可以用train.py或者train_baseline.py

参考github: https://github.com/hytseng0509/CrossDomainFewShot

获取代码和数据

需要下载一段时间

# Fetch the baseline code and datasets from OBS into ./baseline.
# moxing is the ModelArts storage I/O library; copy_parallel mirrors the
# OBS bucket path into the local working directory (may take a while).
import moxing as mox
mox.file.copy_parallel('obs://ma-competitions-bj4/fewshot/baseline','baseline')
# IPython magic: switch the notebook's working directory into the copy.
%cd baseline
/home/ma-user/work/baseline

导入依赖库

!pip install torch==1.5 torchvision==0.6
!pip install -r requirements.txt

import numpy as np
import random
import torch
from data.datamgr import SetDataManager, SimpleDataManager
from options import parse_args, get_resume_file, load_warmup_state
from methods.LFTNet import LFTNet
from methods.backbone import model_dict
from train import cycle,train
from PIL import Image
import torchvision.transforms as transforms
import moxing as mox
import argparse
import os

参数设置

# Build the command-line interface for the few-shot pipeline. Inside the
# notebook nothing is passed on the command line, so every option falls back
# to its default; a handful are then overridden explicitly below.
script = 'train'
parser = argparse.ArgumentParser(description=f'few-shot script {script}')

# Options shared by training and testing.
_common_options = [
    ('--dataset', dict(default='multi', help='miniImagenet/cub/omniglot, specify multi for training with multiple domains')),
    ('--testset', dict(default='cub', help='miniImagenet/cub/omniglot, valid only when dataset=multi')),
    ('--model', dict(default='ResNet10', help='model: Conv{4|6} / ResNet{10|18|34}')),  # we use ResNet10
    ('--method', dict(default='baseline', help='baseline/baseline++/protonet/matchingnet/relationnet{_softmax}/gnnnet')),
    ('--train_n_way', dict(default=5, type=int, help='class num to classify for training')),
    ('--test_n_way', dict(default=5, type=int, help='class num to classify for testing (validation) ')),
    ('--n_shot', dict(default=5, type=int, help='number of labeled data in each class, same as n_support')),
    ('--train_aug', dict(action='store_true', help='perform data augmentation or not during training ')),
    ('--name', dict(default='tmp', type=str, help='')),
    ('--save_dir', dict(default='./output', type=str, help='')),
    ('--data_dir', dict(default='./filelists', type=str, help='')),
]

# Options only meaningful when training.
_train_options = [
    ('--num_classes', dict(default=200, type=int, help='total number of classes in softmax, only used in baseline')),
    ('--save_freq', dict(default=25, type=int, help='Save frequency')),
    ('--start_epoch', dict(default=0, type=int, help='Starting epoch')),
    ('--stop_epoch', dict(default=400, type=int, help='Stopping epoch')),
    ('--resume', dict(default='', type=str, help='continue from previous trained model with largest epoch')),
    ('--resume_epoch', dict(default=50, type=int, help='')),
    ('--warmup', dict(default='gg3b0', type=str, help='continue from baseline, neglected if resume is true')),
]

# Options only meaningful when testing.
_test_options = [
    ('--split', dict(default='novel', help='base/val/novel')),
    ('--save_epoch', dict(default=10, type=int, help='load the model trained in x epoch, use the best model if x is -1')),
]

for _flag, _kwargs in _common_options:
    parser.add_argument(_flag, **_kwargs)

if script == 'train':
    for _flag, _kwargs in _train_options:
        parser.add_argument(_flag, **_kwargs)
elif script == 'test':
    for _flag, _kwargs in _test_options:
        parser.add_argument(_flag, **_kwargs)

# parse_known_args tolerates the notebook kernel's own argv entries.
params, unknown = parser.parse_known_args()

# Notebook-time overrides: train matchingnet with LFT, warm-started from a
# baseline++ pre-trained encoder, resuming from the same-named run.
for _attr, _value in (
    ('method', 'matchingnet'),
    ('name', 'multi_TESTSET_lft_METHOD'),
    ('warmup', 'baseline++'),
    ('resume', 'multi_TESTSET_lft_METHOD'),
):
    setattr(params, _attr, _value)

模型训练

由于数据量过大,要获取完整数据,请进入baseline/filelists/,运行data_download.ipynb

如果使用其他数据集请修改datasets

训练前,将下列代码取消注释


# # --- main function ---
# if __name__=='__main__':

#   # set numpy random seed
#   np.random.seed(10)

#   # parse argument
#   print('--- LFTNet training: {} ---\n'.format(params.name))
#   print(params)

#   # output and tensorboard dir
#   params.tf_dir = '%s/log/%s'%(params.save_dir, params.name)
#   params.checkpoint_dir = '%s/checkpoints/%s'%(params.save_dir, params.name)
#   if not os.path.isdir(params.checkpoint_dir):
#     os.makedirs(params.checkpoint_dir)

#   # dataloader
#   print('\n--- prepare dataloader ---')
#   print('train with multiple seen domains (unseen domain: {})'.format(params.testset))
#   datasets = ['miniImagenet', 'cub', 'omniglot']  #  
#   datasets.remove(params.testset)
#   val_file = os.path.join(params.data_dir, 'miniImagenet', 'val.json')

#   # model
#   print('\n--- build LFTNet model ---')
#   if 'Conv' in params.model:
#     image_size = 84
#   else:
#     image_size = 224

#   n_query = max(1, int(16* params.test_n_way/params.train_n_way))
#   train_few_shot_params   = dict(n_way = params.train_n_way, n_support = params.n_shot)
#   base_datamgr            = SetDataManager(image_size, n_query = n_query,  **train_few_shot_params)
#   aux_datamgr             = SimpleDataManager(image_size, batch_size=16)
#   aux_iter              = iter(cycle(aux_datamgr.get_data_loader(os.path.join(params.data_dir, 'miniImagenet', 'base.json'), aug=params.train_aug)))
#   test_few_shot_params    = dict(n_way = params.test_n_way, n_support = params.n_shot)
#   val_datamgr             = SetDataManager(image_size, n_query = n_query, **test_few_shot_params)
#   val_loader              = val_datamgr.get_data_loader( val_file, aug = False)

#   model = LFTNet(params, tf_path=params.tf_dir)
#   model.cuda()

#   # resume training
#   start_epoch = params.start_epoch
#   stop_epoch = params.stop_epoch
#   if params.resume != '':
#     print(params.save_dir, params.resume, params.resume_epoch)
#     resume_file = get_resume_file('%s/checkpoints/%s'%(params.save_dir, params.resume), params.resume_epoch)
#     print(resume_file)
#     if resume_file is not None:
#       start_epoch = model.resume(resume_file)
#       print('  resume the training with at {} epoch (model file {})'.format(start_epoch, params.resume))
#     else:
#       raise ValueError('No resume file')
#   # load pre-trained feature encoder
#   else:
#     if params.warmup == 'gg3b0':
#       raise Exception('Must provide pre-trained feature-encoder file using --warmup option!')
#     model.model.feature.load_state_dict(load_warmup_state('%s/checkpoints/%s'%(params.save_dir, params.warmup), params.method), strict=False)

#   # training
#   print('\n--- start the training ---')
#   train(base_datamgr, datasets, aux_iter, val_loader, model, start_epoch, stop_epoch, params)
# Copy the best checkpoint produced by training into ./model so it can be
# packaged for the ModelArts model import below.
model_path = 'model/best_model.pth'
_best_ckpt = './output/checkpoints/' + params.name + '/best_model.pth'
mox.file.copy_parallel(_best_ckpt, model_path)

模型测试

feature即模型从图像中提取的特征向量

img_file = './filelists/test.jpg'

# Inference preprocessing: resize ~15% larger than the crop, centre-crop to
# 224x224, convert to a tensor, then normalise with ImageNet statistics.
_crop_size = 224
_resize_size = int(_crop_size * 1.15)
infer_transformation = transforms.Compose([
    transforms.Resize([_resize_size, _resize_size]),
    transforms.CenterCrop(_crop_size),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

# Rebuild a bare ResNet10 backbone and load only the feature-encoder weights
# from the training checkpoint, then extract a feature vector for one image.
model = model_dict['ResNet10']()
tmp = torch.load(model_path, map_location='cpu')

# Checkpoints from different training scripts store the weights under either
# 'state' or 'model_state'; anything else propagates naturally, so the
# original redundant `except: raise` clause is dropped.
try:
    state = tmp['state']
except KeyError:
    state = tmp['model_state']

# Keep only the feature-encoder tensors, stripping the 'feature.' prefix so
# the keys match the bare backbone. The feature-wise transformation
# parameters (gamma/beta) are training-time only and are discarded.
# (Replaces the original pop-while-iterating loop with an unused
# enumerate index.)
state = {
    key.replace('feature.', ''): value
    for key, value in state.items()
    if 'feature.' in key and 'gamma' not in key and 'beta' not in key
}

model.load_state_dict(state)

model.eval()

img = Image.open(img_file).convert('RGB')
img_tensor = torch.unsqueeze(infer_transformation(img), 0)

# Pure inference: disable autograd bookkeeping to save memory/time.
with torch.no_grad():
    feature = model(img_tensor)[0].detach().numpy().tolist()

print(len(feature))
512

模型导入ModelArts

from modelarts.session import Session
from modelarts.model import Model
from modelarts.config.model_config import TransformerConfig,Params
!pip install json5
import moxing as mox
import json5
import re
import traceback
import random

# Register the packaged model with ModelArts. The import is attempted only
# when model/config.json exists, because ModelArts needs that config to
# describe how the model is served.
try:
    session = Session()
    config_path = 'model/config.json'
    if mox.file.exists(config_path):
        model_location = './model'
        load_dict = json5.loads(mox.file.read(config_path))
        model_type = load_dict['model_type']
        # Append a random suffix so repeated runs do not collide on the name.
        model_name = "Cross-Few-Shot" + '_' + str(random.randint(0, 1000))
        print("正在导入模型,模型名称:", model_name)
        model_instance = Model(
            session,
            model_name=model_name,                 # model name
            model_version="1.0.0",                 # model version
            source_location_type='LOCAL_SOURCE',   # import from local files
            source_location=model_location,        # path to the model files
            model_type=model_type,                 # framework type from config.json
        )

    print("所有模型导入完成")
except Exception:
    print("发生了一些问题,请看下面的报错信息:")
    traceback.print_exc()
    print("模型导入失败")
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0)

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。