Implementing the MASS Network with MindSpore
Ever since Turing proposed the famous "Turing test" in 1950, teaching machines to "listen" and "speak" so that they can communicate with humans without barriers has been a long-standing dream of human-computer interaction. In recent years, driven by advances in deep learning, natural language processing has made major breakthroughs and grown into an important branch of artificial intelligence, covering application scenarios such as speech recognition, information retrieval, information extraction, machine translation, and question answering.
Figure 1. Transformer network architecture
"""
Transformer with encoder and decoder.
In Transformer, we define T = src_max_len, T' = tgt_max_len.
Args:
config (TransformerConfig): Model config.
is_training (bool): Whether is training.
use_one_hot_embeddings (bool): Whether use one-hot embedding.
Returns:
Tuple[Tensor], network outputs.
"""
    def __init__(self,
                 config: TransformerConfig,
                 is_training: bool,
                 use_one_hot_embeddings: bool = False,
                 use_positional_embedding: bool = True):
        super(Transformer, self).__init__()

        self.use_positional_embedding = use_positional_embedding
        config = copy.deepcopy(config)
        self.is_training = is_training

        # Disable dropout at inference time.
        if not is_training:
            config.hidden_dropout_prob = 0.0
            config.attention_dropout_prob = 0.0

        self.input_mask_from_dataset = config.input_mask_from_dataset
        self.batch_size = config.batch_size
        self.max_positions = config.seq_length
        self.attn_embed_dim = config.hidden_size
        self.num_layers = config.num_hidden_layers
        self.word_embed_dim = config.hidden_size
        self.last_idx = self.num_layers - 1

        self.embedding_lookup = EmbeddingLookup(
            vocab_size=config.vocab_size,
            embed_dim=self.word_embed_dim,
            use_one_hot_embeddings=use_one_hot_embeddings)

        if self.use_positional_embedding:
            self.positional_embedding = PositionalEmbedding(
                embedding_size=self.word_embed_dim,
                max_position_embeddings=config.max_position_embeddings)

        self.encoder = TransformerEncoder(
            attn_embed_dim=self.attn_embed_dim,
            encoder_layers=self.num_layers,
            num_attn_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            attention_dropout_prob=config.attention_dropout_prob,
            initializer_range=config.initializer_range,
            hidden_dropout_prob=config.hidden_dropout_prob,
            hidden_act=config.hidden_act,
            compute_type=config.compute_type)

        self.decoder = TransformerDecoder(
            attn_embed_dim=self.attn_embed_dim,
            decoder_layers=self.num_layers,
            num_attn_heads=config.num_attention_heads,
            intermediate_size=config.intermediate_size,
            attn_dropout_prob=config.attention_dropout_prob,
            initializer_range=config.initializer_range,
            dropout_prob=config.hidden_dropout_prob,
            hidden_act=config.hidden_act,
            compute_type=config.compute_type)

        self.cast = P.Cast()
        self.dtype = config.dtype
        self.cast_compute_type = SaturateCast(dst_type=config.compute_type)
        self.slice = P.StridedSlice()
        self.dropout = nn.Dropout(keep_prob=1 - config.hidden_dropout_prob)
        self._create_attention_mask_from_input_mask = CreateAttentionMaskFromInputMask(config)
        self.scale = Tensor([math.sqrt(float(self.word_embed_dim))],
                            dtype=mstype.float32)
        self.multiply = P.Mul()
    def construct(self, source_ids, source_mask, target_ids, target_mask):
        """
        Construct network.

        In this method, T = src_max_len, T' = tgt_max_len.

        Args:
            source_ids (Tensor): Source sentences with shape (N, T).
            source_mask (Tensor): Source sentences padding mask with shape (N, T),
                where 0 indicates padding position.
            target_ids (Tensor): Target sentences with shape (N, T').
            target_mask (Tensor): Target sentences padding mask with shape (N, T'),
                where 0 indicates padding position.

        Returns:
            Tuple[Tensor], network outputs.
        """
        # Process source sentences.
        src_embeddings, embedding_tables = self.embedding_lookup(source_ids)
        src_embeddings = self.multiply(src_embeddings, self.scale)
        if self.use_positional_embedding:
            src_embeddings = self.positional_embedding(src_embeddings)
        src_embeddings = self.dropout(src_embeddings)

        # Attention mask with shape (N, T, T).
        enc_attention_mask = self._create_attention_mask_from_input_mask(source_mask)
        # Transformer encoder.
        encoder_output = self.encoder(
            self.cast_compute_type(src_embeddings),      # (N, T, D)
            self.cast_compute_type(enc_attention_mask)   # (N, T, T)
        )

        # Process target sentences.
        tgt_embeddings, _ = self.embedding_lookup(target_ids)
        tgt_embeddings = self.multiply(tgt_embeddings, self.scale)
        if self.use_positional_embedding:
            tgt_embeddings = self.positional_embedding(tgt_embeddings)
        tgt_embeddings = self.dropout(tgt_embeddings)

        # Attention mask with shape (N, T', T').
        tgt_attention_mask = self._create_attention_mask_from_input_mask(
            target_mask, True)
        # Transformer decoder.
        decoder_output = self.decoder(
            self.cast_compute_type(tgt_embeddings),      # (N, T', D)
            self.cast_compute_type(tgt_attention_mask),  # (N, T', T')
            encoder_output,                              # (N, T, D)
            enc_attention_mask                           # (N, T, T)
        )

        return encoder_output, decoder_output, embedding_tables
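For reference, here is a minimal sketch of how this backbone might be instantiated and exercised with dummy inputs. The config file name and the all-ones dummy inputs are illustrative assumptions, not values taken from the MASS scripts.

import numpy as np
from mindspore import Tensor
import mindspore.common.dtype as mstype

# Hypothetical smoke test: build the backbone from a config and run one forward pass.
config = TransformerConfig.from_json_file("config.json")  # file name is an assumption
model = Transformer(config, is_training=True)

N, T = config.batch_size, config.seq_length
ids = Tensor(np.ones((N, T)), mstype.int32)    # dummy token ids
mask = Tensor(np.ones((N, T)), mstype.int32)   # 1 = real token, 0 = padding

encoder_out, decoder_out, embed_table = model(ids, mask, ids, mask)
print(encoder_out.shape, decoder_out.shape)    # expected: (N, T, D), (N, T, D)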
"""
Transformer training network.
Args:
config (TransformerConfig): The config of Transformer.
is_training (bool): Specifies whether to use the training mode.
use_one_hot_embeddings (bool): Specifies whether to use one-hot for embeddings.
Returns:
Tensor, prediction_scores, seq_relationship_score.
"""
    def __init__(self, config, is_training, use_one_hot_embeddings):
        super(TransformerTraining, self).__init__()
        self.transformer = Transformer(config, is_training, use_one_hot_embeddings)
        # PredLogProbs is the output head defined in the MASS model source.
        self.projection = PredLogProbs(config)

    def construct(self, source_ids, source_mask, target_ids, target_mask):
        """
        Construct network.

        Args:
            source_ids (Tensor): Source sentences.
            source_mask (Tensor): Source padding mask.
            target_ids (Tensor): Target sentences.
            target_mask (Tensor): Target padding mask.

        Returns:
            Tensor, prediction scores.
        """
        _, decoder_outputs, embedding_table = \
            self.transformer(source_ids, source_mask, target_ids, target_mask)
        prediction_scores = self.projection(decoder_outputs,
                                            embedding_table)
        return prediction_scores
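Note that the projection reuses the embedding table returned by the backbone, i.e. the output layer shares weights with the input embedding. A minimal sketch of such a tied projection, under the assumption that the decoder output is flattened to (N*T', D) before the matmul (the actual PredLogProbs in the MASS source may differ in details), could look like this:

import mindspore.nn as nn
from mindspore.ops import operations as P


class TiedProjection(nn.Cell):
    """Sketch of a PredLogProbs-style head that shares weights with the embedding."""

    def __init__(self, config):
        super(TiedProjection, self).__init__()
        self.width = config.hidden_size
        self.reshape = P.Reshape()
        self.matmul = P.MatMul(transpose_b=True)   # scores = H @ E^T
        self.log_softmax = nn.LogSoftmax(axis=-1)

    def construct(self, decoder_output, embedding_table):
        # (N, T', D) -> (N * T', D)
        hidden = self.reshape(decoder_output, (-1, self.width))
        # (N * T', D) x (V, D)^T -> (N * T', V) log-probabilities over the vocabulary
        logits = self.matmul(hidden, embedding_table)
        return self.log_softmax(logits)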
from mindspore import context

# device_id is typically read from the DEVICE_ID environment variable.
context.set_context(mode=context.GRAPH_MODE,
                    device_target="Ascend",
                    reserve_class_name_in_scope=False,
                    device_id=device_id)
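The snippet above assumes device_id is already defined, and the dataset code below uses rank_size and rank_id for sharding. A minimal sketch of how these values are commonly obtained in a MindSpore Ascend script (an assumption for illustration, not code from the MASS repository):

import os
from mindspore.communication.management import init, get_rank, get_group_size

device_id = int(os.getenv("DEVICE_ID", "0"))

# Single-card defaults; switch to the collective values when launched distributed.
rank_id, rank_size = 0, 1
if os.getenv("RANK_SIZE") and int(os.getenv("RANK_SIZE")) > 1:
    init()                      # initialize collective communication (HCCL)
    rank_id = get_rank()
    rank_size = get_group_size()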
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as deC
import mindspore.common.dtype as mstype

ds = de.TFRecordDataset(input_files,
                        columns_list=["source_eos_ids", "source_eos_mask",
                                      "target_sos_ids", "target_sos_mask",
                                      "target_eos_ids", "target_eos_mask"],
                        shuffle=shuffle,
                        num_shards=rank_size, shard_id=rank_id,
                        shard_equal_rows=True,
                        num_parallel_workers=8)

ori_dataset_size = ds.get_dataset_size()
print(f" | Dataset size: {ori_dataset_size}.")

repeat_count = epoch_count
if sink_mode:
    ds.set_dataset_size(sink_step * batch_size)
    repeat_count = epoch_count * ori_dataset_size // ds.get_dataset_size()

# Cast all six columns to int32 before batching.
type_cast_op = deC.TypeCast(mstype.int32)
ds = ds.map(input_columns="source_eos_ids", operations=type_cast_op)
ds = ds.map(input_columns="source_eos_mask", operations=type_cast_op)
ds = ds.map(input_columns="target_sos_ids", operations=type_cast_op)
ds = ds.map(input_columns="target_sos_mask", operations=type_cast_op)
ds = ds.map(input_columns="target_eos_ids", operations=type_cast_op)
ds = ds.map(input_columns="target_eos_mask", operations=type_cast_op)

ds = ds.batch(batch_size, drop_remainder=True)
ds = ds.repeat(repeat_count)
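Before starting training it can be useful to pull a single batch and check that the six columns come out with the expected shapes and dtypes. This is an illustrative snippet, not part of the MASS scripts:

# Peek at one batch: each column should be int32 with shape (batch_size, seq_len).
for batch in ds.create_dict_iterator():
    for name, value in batch.items():
        print(f"{name}: shape={value.shape}, dtype={value.dtype}")
    break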
import os

from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.optim import Adam
from mindspore.train.model import Model
from mindspore.train.callback import CheckpointConfig, ModelCheckpoint
from mindspore.train.loss_scale_manager import DynamicLossScaleManager

from config import TransformerConfig
from src.transformer import TransformerNetworkWithLoss, TransformerTrainOneStepWithLossScaleCell
# LossCallBack and polynomial_decay_scheduler are helpers provided by the MASS source.

config = TransformerConfig.from_json_file("config.json")
net_with_loss = TransformerNetworkWithLoss(config, is_training=True)
net_with_loss.init_parameters_data()

# Per-step learning rates: warmup followed by polynomial decay.
lr = Tensor(polynomial_decay_scheduler(lr=config.lr,
                                       min_lr=config.min_lr,
                                       decay_steps=config.decay_steps,
                                       total_update_num=update_steps,
                                       warmup_steps=config.warmup_steps,
                                       power=config.poly_lr_scheduler_power),
            dtype=mstype.float32)
optimizer = Adam(net_with_loss.trainable_params(), lr, beta1=0.9, beta2=0.98)

# Dynamic loss scaling for mixed-precision training.
scale_manager = DynamicLossScaleManager(init_loss_scale=config.init_loss_scale,
                                        scale_factor=config.loss_scale_factor,
                                        scale_window=config.scale_window)
net_with_grads = TransformerTrainOneStepWithLossScaleCell(
    network=net_with_loss,
    optimizer=optimizer,
    scale_update_cell=scale_manager.get_update_cell())
net_with_grads.set_train(True)

model = Model(net_with_grads)
loss_monitor = LossCallBack(config)
ckpt_config = CheckpointConfig(save_checkpoint_steps=config.save_ckpt_steps,
                               keep_checkpoint_max=config.keep_ckpt_max)
ckpt_callback = ModelCheckpoint(prefix=config.ckpt_prefix,
                                directory=os.path.join(config.ckpt_path,
                                                       'ckpt_{}'.format(os.getenv('DEVICE_ID'))),
                                config=ckpt_config)
callbacks = [loss_monitor, ckpt_callback]
model.train(epoch_size, pre_training_dataset,
            callbacks=callbacks, dataset_sink_mode=config.dataset_sink_mode)
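The polynomial_decay_scheduler used above comes from the MASS source and is expected to return one learning rate per update step. A minimal sketch of such a warmup-plus-polynomial-decay schedule, written here only to illustrate the shape of the curve (the exact warmup behaviour in the original implementation may differ), is:

import numpy as np

def polynomial_decay_with_warmup(lr, min_lr, decay_steps, total_update_num,
                                 warmup_steps, power):
    """Illustrative schedule: linear warmup to `lr`, then polynomial decay to `min_lr`."""
    lrs = np.zeros(int(total_update_num), dtype=np.float32)
    for step in range(int(total_update_num)):
        if step < warmup_steps:
            # Linear warmup from 0 up to the peak learning rate.
            lrs[step] = lr * (step + 1) / warmup_steps
        else:
            # Polynomial decay over `decay_steps`, clipped at `min_lr` afterwards.
            progress = min(step - warmup_steps, decay_steps) / decay_steps
            lrs[step] = (lr - min_lr) * (1.0 - progress) ** power + min_lr
    return lrs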
import time
import pickle

from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.train import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from config import TransformerConfig
from .transformer_for_infer import TransformerInferModel
from src.utils import get_score
# TransformerInferCell and load_infer_weights are provided by the MASS inference source;
# dataset, args, config_path and checkpoint_path are prepared by the surrounding script.

config = TransformerConfig.from_json_file(config_path)
tfm_model = TransformerInferModel(config=config, use_one_hot_embeddings=False)
tfm_model.init_parameters_data()

weights = load_infer_weights(checkpoint_path)
load_param_into_net(tfm_model, weights)

tfm_infer = TransformerInferCell(tfm_model)
model = Model(tfm_infer)
predictions = []
probs = []
source_sentences = []
target_sentences = []
for batch in dataset.create_dict_iterator():
    source_sentences.append(batch["source_eos_ids"])
    target_sentences.append(batch["target_eos_ids"])

    source_ids = Tensor(batch["source_eos_ids"], mstype.int32)
    source_mask = Tensor(batch["source_eos_mask"], mstype.int32)

    start_time = time.time()
    predicted_ids, entire_probs = model.predict(source_ids, source_mask)
    print(f" | Batch size: {config.batch_size}, "
          f"Time cost: {time.time() - start_time}.")

    predictions.append(predicted_ids.asnumpy())
    probs.append(entire_probs.asnumpy())

output = []
for inputs, ref, batch_out, batch_probs in zip(source_sentences,
                                               target_sentences,
                                               predictions,
                                               probs):
    for i in range(config.batch_size):
        if batch_out.ndim == 3:
            batch_out = batch_out[:, 0]
        example = {"source": inputs[i].tolist(),
                   "target": ref[i].tolist(),
                   "prediction": batch_out[i].tolist(),
                   "prediction_prob": batch_probs[i].tolist()}
        output.append(example)

score = get_score(output, vocab=args.vocab, metric=args.metric)
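The `import pickle` at the top of this snippet suggests the collected `output` list is also serialized for later analysis. A minimal sketch of that step; the file name is an assumption:

# Persist the id-level predictions so scoring can be rerun without repeating inference.
with open("predictions.pkl", "wb") as f:      # illustrative file name
    pickle.dump(output, f, protocol=pickle.HIGHEST_PROTOCOL)

print(f" | {args.metric} score: {score}.")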
[1] Li Zhoujun, Fan Yu, Wu Xianjie. A Survey of Pre-training Techniques for Natural Language Processing (in Chinese). Computer Science, 2020.
[2] From Word Embedding to BERT: A History of Pre-training Techniques in NLP (in Chinese). https://zhuanlan.zhihu.com/p/49271699
[3] Song K., Tan X., Qin T., et al. MASS: Masked Sequence to Sequence Pre-training for Language Generation. ICML 2019.