His reply:
The training code is as follows:

import numpy as np
import os
import random
import torch
from data.datamgr import SetDataManager, SimpleDataManager
from options import parse_args, get_resume_file, load_warmup_state
from methods.LFTNet import LFTNet

# endlessly cycle a finite iterable (e.g. a DataLoader) without raising StopIteration
def cycle(iterable):
    while True:
        for x in iterable:
            yield x

# training iterations
def train(base_datamgr, base_set, aux_iter, val_loader, model, start_epoch, stop_epoch, params):

    # track the best validation accuracy
    max_acc = 0
    total_it = 0

    for epoch in range(start_epoch, stop_epoch):

        # randomly split the seen domains into pseudo-seen and pseudo-unseen domains
        random_set = random.sample(base_set, k=2)
        ps_set = random_set[0]
        pu_set = random_set[1:]
        ps_loader = base_datamgr.get_data_loader(os.path.join(params.data_dir, ps_set, 'base.json'), aug=params.train_aug)
        pu_loader = base_datamgr.get_data_loader([os.path.join(params.data_dir, dataset, 'base.json') for dataset in pu_set], aug=params.train_aug)

        # train loop
        model.train()
        total_it = model.trainall_loop(epoch, ps_loader, pu_loader, aux_iter, total_it)

        # validate
        model.eval()
        with torch.no_grad():
            acc = model.test_loop(val_loader)

        # save the best model so far
        if acc > max_acc:
            print('best model! save...')
            max_acc = acc
            outfile = os.path.join(params.checkpoint_dir, 'best_model.tar')
            model.save(outfile, epoch)
        else:
            print('GG!! best accuracy {:f}'.format(max_acc))

        # periodic checkpoint
        if ((epoch + 1) % params.save_freq == 0) or (epoch == stop_epoch - 1):
            outfile = os.path.join(params.checkpoint_dir, '{:d}.tar'.format(epoch + 1))
            model.save(outfile, epoch)

    return
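An aside on the domain split inside train() (separate from the script): random.sample(base_set, k=2) draws two distinct seen domains each epoch, so ps_set is a single domain name while pu_set = random_set[1:] stays a one-element list, which is why the pseudo-unseen loader is built from a list of JSON paths. A quick illustration with hypothetical remaining seen domains:

import random

seen = ['miniImagenet', 'cars', 'places', 'cub']  # e.g. seen domains left after removing the test set
random_set = random.sample(seen, k=2)             # two distinct domains, e.g. ['cars', 'cub']
ps_set = random_set[0]                            # pseudo-seen domain:    'cars'
pu_set = random_set[1:]                           # pseudo-unseen domains: ['cub']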
# --- main function ---
if __name__ == '__main__':

    # set numpy random seed
    np.random.seed(10)

    # parse arguments
    params = parse_args('train')
    print('--- LFTNet training: {} ---\n'.format(params.name))
    print(params)

    # output and tensorboard dir
    params.tf_dir = '%s/log/%s'%(params.save_dir, params.name)
    params.checkpoint_dir = '%s/checkpoints/%s'%(params.save_dir, params.name)
    if not os.path.isdir(params.checkpoint_dir):
        os.makedirs(params.checkpoint_dir)

    # dataloader
    print('\n--- prepare dataloader ---')
    print('  train with multiple seen domains (unseen domain: {})'.format(params.testset))
    datasets = ['miniImagenet', 'cars', 'places', 'cub', 'plantae']
    datasets.remove(params.testset)
    val_file = os.path.join(params.data_dir, 'miniImagenet', 'val.json')

    # model
    print('\n--- build LFTNet model ---')
    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    n_query = max(1, int(16 * params.test_n_way / params.train_n_way))
    train_few_shot_params = dict(n_way=params.train_n_way, n_support=params.n_shot)
    base_datamgr = SetDataManager(image_size, n_query=n_query, **train_few_shot_params)
    aux_datamgr = SimpleDataManager(image_size, batch_size=16)
    aux_iter = iter(cycle(aux_datamgr.get_data_loader(os.path.join(params.data_dir, 'miniImagenet', 'base.json'), aug=params.train_aug)))
    test_few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)
    val_datamgr = SetDataManager(image_size, n_query=n_query, **test_few_shot_params)
    val_loader = val_datamgr.get_data_loader(val_file, aug=False)
    model = LFTNet(params, tf_path=params.tf_dir)
    model.cuda()

    # resume training
    start_epoch = params.start_epoch
    stop_epoch = params.stop_epoch
    if params.resume != '':
        resume_file = get_resume_file('%s/checkpoints/%s'%(params.save_dir, params.resume), params.resume_epoch)
        if resume_file is not None:
            start_epoch = model.resume(resume_file)
            print('  resume the training at epoch {} (model file {})'.format(start_epoch, params.resume))
        else:
            raise ValueError('No resume file')
    # otherwise, load the pre-trained feature encoder
    else:
        if params.warmup == 'gg3b0':
            raise Exception('Must provide pre-trained feature-encoder file using --warmup option!')
        model.model.feature.load_state_dict(load_warmup_state('%s/checkpoints/%s'%(params.save_dir, params.warmup), params.method), strict=False)

    # training
    print('\n--- start the training ---')
    train(base_datamgr, datasets, aux_iter, val_loader, model, start_epoch, stop_epoch, params)
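An aside on the cycle helper: it wraps a finite loader into an endless generator, so aux_iter can be advanced with next() indefinitely inside trainall_loop without ever raising StopIteration (the outer iter() call is redundant but harmless, since cycle already returns a generator). A minimal, self-contained sketch of the behavior:

def cycle(iterable):
    while True:
        for x in iterable:
            yield x

it = iter(cycle([1, 2, 3]))          # the list stands in for a finite DataLoader
print([next(it) for _ in range(7)])  # -> [1, 2, 3, 1, 2, 3, 1]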
The test code is as follows:

import torch
import os
import h5py
import random
import numpy as np
from methods import backbone
from methods.backbone import model_dict
from data.datamgr import SimpleDataManager
from options import parse_args, get_best_file, get_assigned_file
from methods.protonet import ProtoNet
from methods.matchingnet import MatchingNet
from methods.gnnnet import GnnNet
from methods.relationnet import RelationNet
import data.feature_loader as feat_loader

# extract image features with the trained encoder and save them to an HDF5 file
def save_features(model, data_loader, featurefile):
    f = h5py.File(featurefile, 'w')
    max_count = len(data_loader) * data_loader.batch_size
    all_labels = f.create_dataset('all_labels', (max_count,), dtype='i')
    all_feats = None
    count = 0
    for i, (x, y) in enumerate(data_loader):
        if (i % 10) == 0:
            print('    {:d}/{:d}'.format(i, len(data_loader)))
        x = x.cuda()
        feats = model(x)
        # create the feature dataset lazily, once the feature dimension is known
        if all_feats is None:
            all_feats = f.create_dataset('all_feats', [max_count] + list(feats.size()[1:]), dtype='f')
        all_feats[count:count + feats.size(0)] = feats.data.cpu().numpy()
        all_labels[count:count + feats.size(0)] = y.cpu().numpy()
        count = count + feats.size(0)
    count_var = f.create_dataset('count', (1,), dtype='i')
    count_var[0] = count
    f.close()

# evaluate one few-shot episode using the saved features
def feature_evaluation(cl_data_file, model, n_way=5, n_support=5, n_query=15):
    class_list = list(cl_data_file.keys())  # materialize the keys view so random.sample works in Python 3
    select_class = random.sample(class_list, n_way)
    z_all = []
    for cl in select_class:
        img_feat = cl_data_file[cl]
        perm_ids = np.random.permutation(len(img_feat)).tolist()
        z_all.append([np.squeeze(img_feat[perm_ids[i]]) for i in range(n_support + n_query)])
    z_all = torch.from_numpy(np.array(z_all))
    model.n_query = n_query
    scores = model.set_forward(z_all, is_feature=True)
    pred = scores.data.cpu().numpy().argmax(axis=1)
    y = np.repeat(range(n_way), n_query)
    acc = np.mean(pred == y) * 100
    return acc

# --- main ---
if __name__ == '__main__':

    # parse arguments
    params = parse_args('test')
    print('Testing! {} shots on {} dataset with {} epochs of {}({})'.format(params.n_shot, params.dataset, params.save_epoch, params.name, params.method))
    remove_featurefile = True

    print('\nStage 1: saving features')

    # dataset
    print('  build dataset')
    if 'Conv' in params.model:
        image_size = 84
    else:
        image_size = 224
    split = params.split
    loadfile = os.path.join(params.data_dir, params.dataset, split + '.json')
    datamgr = SimpleDataManager(image_size, batch_size=64)
    data_loader = datamgr.get_data_loader(loadfile, aug=False)

    # feature encoder
    print('  build feature encoder')
    checkpoint_dir = '%s/checkpoints/%s'%(params.save_dir, params.name)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
    if params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            model = backbone.Conv4NP()
        elif params.model == 'Conv6':
            model = backbone.Conv6NP()
        else:
            model = model_dict[params.model](flatten=False)
    else:
        model = model_dict[params.model]()
    model = model.cuda()
    tmp = torch.load(modelfile)
    try:
        state = tmp['state']
    except KeyError:
        state = tmp['model_state']
    except:
        raise
    # keep only feature-encoder weights; drop the feature-wise transformation (gamma/beta) and other parameters
    state_keys = list(state.keys())
    for key in state_keys:
        if 'feature.' in key and 'gamma' not in key and 'beta' not in key:
            newkey = key.replace('feature.', '')
            state[newkey] = state.pop(key)
        else:
            state.pop(key)
    model.load_state_dict(state)
    model.eval()

    # save feature file
    print('  extract and save features...')
    if params.save_epoch != -1:
        featurefile = os.path.join(checkpoint_dir.replace('checkpoints', 'features'), split + '_' + str(params.save_epoch) + '.hdf5')
    else:
        featurefile = os.path.join(checkpoint_dir.replace('checkpoints', 'features'), split + '.hdf5')
    dirname = os.path.dirname(featurefile)
    if not os.path.isdir(dirname):
        os.makedirs(dirname)
    save_features(model, data_loader, featurefile)
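An aside on the saved feature file (separate from the script): save_features writes three HDF5 datasets, all_feats, all_labels, and count. The repo's feat_loader.init_loader presumably groups the features by class label; a rough, hypothetical sketch of such a reader, assuming only h5py and numpy:

import h5py
import numpy as np

# illustrative reader; the actual data.feature_loader implementation may differ
def load_feature_file(featurefile):
    with h5py.File(featurefile, 'r') as f:
        count = f['count'][0]             # number of rows actually written
        feats = f['all_feats'][:count]    # (count, feat_dim) float32 features
        labels = f['all_labels'][:count]  # (count,) integer class labels
    # map each class label to the array of its feature vectors
    return {cl: feats[labels == cl] for cl in np.unique(labels)}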
    print('\nStage 2: evaluate')
    acc_all = []
    iter_num = 1000
    few_shot_params = dict(n_way=params.test_n_way, n_support=params.n_shot)

    # build the metric-based model
    print('  build metric-based model')
    if params.method == 'protonet':
        model = ProtoNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'matchingnet':
        model = MatchingNet(model_dict[params.model], **few_shot_params)
    elif params.method == 'gnnnet':
        model = GnnNet(model_dict[params.model], **few_shot_params)
    elif params.method in ['relationnet', 'relationnet_softmax']:
        if params.model == 'Conv4':
            feature_model = backbone.Conv4NP
        elif params.model == 'Conv6':
            feature_model = backbone.Conv6NP
        else:
            feature_model = model_dict[params.model]
        loss_type = 'mse' if params.method == 'relationnet' else 'softmax'
        model = RelationNet(feature_model, loss_type=loss_type, **few_shot_params)
    else:
        raise ValueError('Unknown method')
    model = model.cuda()
    model.eval()

    # load model
    checkpoint_dir = '%s/checkpoints/%s'%(params.save_dir, params.name)
    if params.save_epoch != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_epoch)
    else:
        modelfile = get_best_file(checkpoint_dir)
    if modelfile is not None:
        tmp = torch.load(modelfile)
        try:
            model.load_state_dict(tmp['state'])
        except RuntimeError:
            print('warning! RuntimeError when load_state_dict()!')
            model.load_state_dict(tmp['state'], strict=False)
        except KeyError:
            for k in tmp['model_state']:  # TODO: revise later
                if 'running' in k:
                    tmp['model_state'][k] = tmp['model_state'][k].squeeze()
            model.load_state_dict(tmp['model_state'], strict=False)
        except:
            raise

    # load feature file
    print('  load saved feature file')
    cl_data_file = feat_loader.init_loader(featurefile)

    # evaluate over many random episodes
    print('  evaluate')
    for i in range(iter_num):
        acc = feature_evaluation(cl_data_file, model, n_query=15, **few_shot_params)
        acc_all.append(acc)

    # statistics
    print('  get statistics')
    acc_all = np.asarray(acc_all)
    acc_mean = np.mean(acc_all)
    acc_std = np.std(acc_all)
    print('  %d test iterations: Acc = %4.2f%% +- %4.2f%%' % (iter_num, acc_mean, 1.96 * acc_std / np.sqrt(iter_num)))

    # remove feature file [optional]
    if remove_featurefile:
        os.remove(featurefile)
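The reported error bar is the normal-approximation 95% confidence interval for the mean episode accuracy, i.e. mean +- 1.96 * std / sqrt(iter_num). A small self-contained check of the formula (with made-up per-episode accuracies, not real results):

import numpy as np

acc_all = np.random.normal(70.0, 8.0, size=1000)     # hypothetical per-episode accuracies (%)
acc_mean = acc_all.mean()
ci95 = 1.96 * acc_all.std() / np.sqrt(len(acc_all))  # half-width of the 95% CI of the mean
print('%d test iterations: Acc = %4.2f%% +- %4.2f%%' % (len(acc_all), acc_mean, ci95))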