Chinese Speech Recognition with DFCNN + Transformer (Part 3)

Posted by HWCloudAI on 2022/12/19 11:41:11
[Abstract] 10. Language model training — prepare the training hyperparameters and data: def language_model_hparams(): params = HParams( num_heads = 8, num_blocks = 6, input_vocab_size = 50, label_vocab_size = 50, max_length = 10...

10. Language Model Training

Prepare the training hyperparameters and data

def language_model_hparams():
    params = HParams(
        num_heads = 8,
        num_blocks = 6,
        input_vocab_size = 50,
        label_vocab_size = 50,
        max_length = 100,
        hidden_units = 512,
        dropout_rate = 0.2,
        learning_rate = 0.0003,
        is_training = True)
    return params

language_model_args = language_model_hparams()
language_model_args.input_vocab_size = len(train_data.pin_vocab)
language_model_args.label_vocab_size = len(train_data.han_vocab)
language = language_model(language_model_args)

print('Language model hyperparameters:')
print(language_model_args)
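`HParams` above comes from `tf.contrib.training` in TF 1.x, which no longer exists in TF 2.x. As a rough stand-in sketch (not the original class) covering only the behavior this tutorial relies on — keyword construction, attribute read/write, and the sorted key-value printout shown in the log — something like this would do:

```python
# Minimal HParams-like container; an assumption-based sketch, not
# tf.contrib.training.HParams itself.
class HParams:
    def __init__(self, **kwargs):
        # store every hyperparameter as a plain attribute
        self.__dict__.update(kwargs)

    def __repr__(self):
        # print as a sorted list of (name, value) pairs, like the log output
        return repr(sorted(self.__dict__.items()))

params = HParams(num_heads=8, num_blocks=6, input_vocab_size=50)
params.input_vocab_size = 353  # overwrite once the real vocab size is known
print(params)
```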
WARNING:tensorflow:From <ipython-input-13-89328ea9c47f>:23: dropout (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.

Instructions for updating:

Use keras.layers.dropout instead.

WARNING:tensorflow:From <ipython-input-10-c0da1b61dd9f>:15: dense (from tensorflow.python.layers.core) is deprecated and will be removed in a future version.

Instructions for updating:

Use keras.layers.dense instead.

WARNING:tensorflow:From <ipython-input-11-0657c4b451a0>:8: conv1d (from tensorflow.python.layers.convolutional) is deprecated and will be removed in a future version.

Instructions for updating:

Use keras.layers.conv1d instead.

WARNING:tensorflow:From <ipython-input-13-89328ea9c47f>:40: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.

Instructions for updating:

Use tf.cast instead.


Language model hyperparameters:

[('dropout_rate', 0.2), ('hidden_units', 512), ('input_vocab_size', 353), ('is_training', True), ('label_vocab_size', 415), ('learning_rate', 0.0003), ('max_length', 100), ('num_blocks', 6), ('num_heads', 8)]

Train the language model

epochs = 20
print("Number of training epochs:", epochs)

print("\nStarting training!")
with language.graph.as_default():
    saver = tf.train.Saver()
with tf.Session(graph=language.graph) as sess:
    merged = tf.summary.merge_all()
    sess.run(tf.global_variables_initializer())
    # note the leading './' -- the original check used an absolute path
    # ('/speech_recognition/...') and so never found the checkpoint
    if os.path.exists('./speech_recognition/language_model/model.meta'):
        print('Loading existing language model')
        saver.restore(sess, './speech_recognition/language_model/model')
    writer = tf.summary.FileWriter('./speech_recognition/language_model/tensorboard', tf.get_default_graph())
    for k in range(epochs):
        total_loss = 0
        batch = train_data.get_language_model_batch()
        # batch_num is computed in the data-preparation step of this series
        for i in range(batch_num):
            input_batch, label_batch = next(batch)
            feed = {language.x: input_batch, language.y: label_batch}
            cost, _ = sess.run([language.mean_loss, language.train_op], feed_dict=feed)
            total_loss += cost
            if (k * batch_num + i) % 10 == 0:
                rs = sess.run(merged, feed_dict=feed)
                writer.add_summary(rs, k * batch_num + i)
        print('Epoch', k + 1, ': average loss =', total_loss / batch_num)
    print("\nTraining finished, saving model")
    saver.save(sess, './speech_recognition/language_model/model')
    writer.close()
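The loop above depends on `train_data.get_language_model_batch()` and `batch_num`, both defined in the data-preparation part of this series. For reference only, a hypothetical padded-batch generator (the function name, zero-padding scheme, and `<PAD>` convention are assumptions, not the tutorial's actual implementation) could look like:

```python
import numpy as np

def language_model_batches(pinyin_ids, hanzi_ids, batch_size):
    """Yield padded (input, label) batches for the language model.

    pinyin_ids / hanzi_ids: lists of integer-index sequences of equal
    length (one hanzi per pinyin syllable). Shorter sequences in a batch
    are right-padded with 0, assumed to be the <PAD> index.
    """
    for start in range(0, len(pinyin_ids) - batch_size + 1, batch_size):
        xs = pinyin_ids[start:start + batch_size]
        ys = hanzi_ids[start:start + batch_size]
        max_len = max(len(seq) for seq in xs)
        x = np.zeros((batch_size, max_len), dtype=np.int32)
        y = np.zeros((batch_size, max_len), dtype=np.int32)
        for i, (px, hy) in enumerate(zip(xs, ys)):
            x[i, :len(px)] = px
            y[i, :len(hy)] = hy
        yield x, y

# two toy sentences: a 3-syllable one and a 2-syllable one
x, y = next(language_model_batches([[1, 2, 3], [4, 5]],
                                   [[7, 8, 9], [10, 11]], batch_size=2))
```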
Number of training epochs: 20



Starting training!

Epoch 1 : average loss = 6.278316140174866

Epoch 2 : average loss = 3.2773803949356077

Epoch 3 : average loss = 1.9082921266555786

Epoch 4 : average loss = 1.3021608710289

Epoch 5 : average loss = 1.1006220042705537

Epoch 6 : average loss = 1.07493497133255

Epoch 7 : average loss = 1.0841107785701751

Epoch 8 : average loss = 1.0857421994209289

Epoch 9 : average loss = 1.0827948033809662

Epoch 10 : average loss = 1.1036933809518814

Epoch 11 : average loss = 1.102576231956482

Epoch 12 : average loss = 1.083077570796013

Epoch 13 : average loss = 1.0762047052383423

Epoch 14 : average loss = 1.0929352223873139

Epoch 15 : average loss = 1.0970239818096161

Epoch 16 : average loss = 1.0907334685325623

Epoch 17 : average loss = 1.0823354721069336

Epoch 18 : average loss = 1.0856867730617523

Epoch 19 : average loss = 1.0625348567962647

Epoch 20 : average loss = 1.0500436395406723



Training finished, saving model

11. Model Testing

Prepare the dictionaries needed for decoding; they must match the ones used in training. Alternatively, the dictionaries can be saved to disk and read back directly.

data_args = data_hparams()
data_args.data_length = 20 # comment out this line when retraining from scratch
train_data = get_data(data_args)
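As noted, the vocabularies can be saved to disk so decoding does not depend on rebuilding them from the corpus each time. A minimal sketch, assuming `train_data.pin_vocab` and `train_data.han_vocab` are plain lists of strings (as in this tutorial's `get_data`):

```python
import json
import os
import tempfile

def save_vocab(vocab, path):
    # ensure_ascii=False keeps hanzi readable in the file
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(vocab, f, ensure_ascii=False)

def load_vocab(path):
    with open(path, encoding='utf-8') as f:
        return json.load(f)

# round-trip check with a toy vocabulary
path = os.path.join(tempfile.gettempdir(), 'pin_vocab.json')
save_vocab(['<PAD>', 'lv4', 'shi4'], path)
restored = load_vocab(path)
```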

Prepare the data needed for testing; it does not have to match the training data.

In this exercise, the demo dataset and model are deliberately small for teaching purposes, so we do not separate the training and test sets.

test_data = get_data(data_args)
acoustic_model_batch = test_data.get_acoustic_model_batch()
language_model_batch = test_data.get_language_model_batch()

Load the trained acoustic model

acoustic_model_args = acoustic_model_hparams()
acoustic_model_args.vocab_size = len(train_data.acoustic_vocab)
acoustic = acoustic_model(acoustic_model_args)
acoustic.ctc_model.summary()
acoustic.ctc_model.load_weights('./speech_recognition/acoustic_model/model.h5')

print('Acoustic model hyperparameters:')
print(acoustic_model_args)
print('\nAcoustic model loaded!')
__________________________________________________________________________________________________

Layer (type)                    Output Shape         Param #     Connected to                     

==================================================================================================

the_inputs (InputLayer)         (None, None, 200, 1) 0                                            

__________________________________________________________________________________________________

conv2d_21 (Conv2D)              (None, None, 200, 32 320         the_inputs[0][0]                 

__________________________________________________________________________________________________

batch_normalization_21 (BatchNo (None, None, 200, 32 128         conv2d_21[0][0]                  

__________________________________________________________________________________________________

conv2d_22 (Conv2D)              (None, None, 200, 32 9248        batch_normalization_21[0][0]     

__________________________________________________________________________________________________

batch_normalization_22 (BatchNo (None, None, 200, 32 128         conv2d_22[0][0]                  

__________________________________________________________________________________________________

max_pooling2d_7 (MaxPooling2D)  (None, None, 100, 32 0           batch_normalization_22[0][0]     

__________________________________________________________________________________________________

conv2d_23 (Conv2D)              (None, None, 100, 64 18496       max_pooling2d_7[0][0]            

__________________________________________________________________________________________________

batch_normalization_23 (BatchNo (None, None, 100, 64 256         conv2d_23[0][0]                  

__________________________________________________________________________________________________

conv2d_24 (Conv2D)              (None, None, 100, 64 36928       batch_normalization_23[0][0]     

__________________________________________________________________________________________________

batch_normalization_24 (BatchNo (None, None, 100, 64 256         conv2d_24[0][0]                  

__________________________________________________________________________________________________

max_pooling2d_8 (MaxPooling2D)  (None, None, 50, 64) 0           batch_normalization_24[0][0]     

__________________________________________________________________________________________________

conv2d_25 (Conv2D)              (None, None, 50, 128 73856       max_pooling2d_8[0][0]            

__________________________________________________________________________________________________

batch_normalization_25 (BatchNo (None, None, 50, 128 512         conv2d_25[0][0]                  

__________________________________________________________________________________________________

conv2d_26 (Conv2D)              (None, None, 50, 128 147584      batch_normalization_25[0][0]     

__________________________________________________________________________________________________

batch_normalization_26 (BatchNo (None, None, 50, 128 512         conv2d_26[0][0]                  

__________________________________________________________________________________________________

max_pooling2d_9 (MaxPooling2D)  (None, None, 25, 128 0           batch_normalization_26[0][0]     

__________________________________________________________________________________________________

conv2d_27 (Conv2D)              (None, None, 25, 128 147584      max_pooling2d_9[0][0]            

__________________________________________________________________________________________________

batch_normalization_27 (BatchNo (None, None, 25, 128 512         conv2d_27[0][0]                  

__________________________________________________________________________________________________

conv2d_28 (Conv2D)              (None, None, 25, 128 147584      batch_normalization_27[0][0]     

__________________________________________________________________________________________________

batch_normalization_28 (BatchNo (None, None, 25, 128 512         conv2d_28[0][0]                  

__________________________________________________________________________________________________

conv2d_29 (Conv2D)              (None, None, 25, 128 147584      batch_normalization_28[0][0]     

__________________________________________________________________________________________________

batch_normalization_29 (BatchNo (None, None, 25, 128 512         conv2d_29[0][0]                  

__________________________________________________________________________________________________

conv2d_30 (Conv2D)              (None, None, 25, 128 147584      batch_normalization_29[0][0]     

__________________________________________________________________________________________________

batch_normalization_30 (BatchNo (None, None, 25, 128 512         conv2d_30[0][0]                  

__________________________________________________________________________________________________

reshape_3 (Reshape)             (None, None, 3200)   0           batch_normalization_30[0][0]     

__________________________________________________________________________________________________

dropout_5 (Dropout)             (None, None, 3200)   0           reshape_3[0][0]                  

__________________________________________________________________________________________________

dense_5 (Dense)                 (None, None, 256)    819456      dropout_5[0][0]                  

__________________________________________________________________________________________________

dropout_6 (Dropout)             (None, None, 256)    0           dense_5[0][0]                    

__________________________________________________________________________________________________

the_labels (InputLayer)         (None, None)         0                                            

__________________________________________________________________________________________________

dense_6 (Dense)                 (None, None, 353)    90721       dropout_6[0][0]                  

__________________________________________________________________________________________________

input_length (InputLayer)       (None, 1)            0                                            

__________________________________________________________________________________________________

label_length (InputLayer)       (None, 1)            0                                            

__________________________________________________________________________________________________

ctc (Lambda)                    (None, 1)            0           the_labels[0][0]                 

                                                                 dense_6[0][0]                    

                                                                 input_length[0][0]               

                                                                 label_length[0][0]               

==================================================================================================

Total params: 1,790,785

Trainable params: 1,788,865

Non-trainable params: 1,920

__________________________________________________________________________________________________

Acoustic model hyperparameters:

[('is_training', True), ('learning_rate', 0.0008), ('vocab_size', 353)]



Acoustic model loaded!

Load the trained language model

language_model_args = language_model_hparams()
language_model_args.input_vocab_size = len(train_data.pin_vocab)
language_model_args.label_vocab_size = len(train_data.han_vocab)
language = language_model(language_model_args)
sess = tf.Session(graph=language.graph)
with language.graph.as_default():
    saver = tf.train.Saver()
with sess.as_default():
    saver.restore(sess, './speech_recognition/language_model/model')

print('Language model hyperparameters:')
print(language_model_args)
print('\nLanguage model loaded!')
WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/tensorflow/python/training/saver.py:1266: checkpoint_exists (from tensorflow.python.training.checkpoint_management) is deprecated and will be removed in a future version.

Instructions for updating:

Use standard file APIs to check for files with this prefix.


Language model hyperparameters:

[('dropout_rate', 0.2), ('hidden_units', 512), ('input_vocab_size', 353), ('is_training', True), ('label_vocab_size', 415), ('learning_rate', 0.0003), ('max_length', 100), ('num_blocks', 6), ('num_heads', 8)]



Language model loaded!

Define the decoder

import numpy as np
from keras import backend as K

def decode_ctc(num_result, num2word):
    """Greedy CTC decoding: map the acoustic model's per-frame
    probabilities to a pinyin sequence."""
    result = num_result[:, :, :]
    in_len = np.zeros((1), dtype=np.int32)
    in_len[0] = result.shape[1]          # number of time frames
    t = K.ctc_decode(result, in_len, greedy=True, beam_width=10, top_paths=1)
    v = K.get_value(t[0][0])[0]          # best path for the single utterance
    text = [num2word[i] for i in v]      # indices -> pinyin syllables
    return v, text
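`K.ctc_decode(..., greedy=True)` hides the actual decoding rule. For intuition only, the greedy path collapse it performs — per-frame argmax, merge consecutive repeats, then drop blanks — can be sketched in NumPy. The blank index is left as a parameter here, since its position depends on how the CTC loss was configured:

```python
import numpy as np

def greedy_ctc_collapse(probs, blank):
    """Sketch of greedy CTC decoding: argmax each frame, merge
    consecutive duplicates, drop blank frames.

    probs: (time, vocab) array of per-frame class probabilities.
    """
    best = np.argmax(probs, axis=-1)
    out, prev = [], None
    for t in best:
        if t != prev and t != blank:
            out.append(int(t))
        prev = t
    return out

# frames: a a <blank> b b  ->  [a, b]   (class 0 plays the blank here)
seq1 = greedy_ctc_collapse(np.eye(3)[[1, 1, 0, 2, 2]], blank=0)
# a <blank> a  ->  [a, a]  (blank separates a genuine repeat)
seq2 = greedy_ctc_collapse(np.eye(3)[[1, 0, 1]], blank=0)
```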

Test the assembled speech recognition system

For 10 sample utterances, we print the reference pinyin and the recognized pinyin, followed by the reference characters and the recognized characters.

for i in range(10):
    print('\nExample', i + 1)
    # run the trained acoustic model on one utterance
    inputs, outputs = next(acoustic_model_batch)
    x = inputs['the_inputs']
    y = inputs['the_labels'][0]
    result = acoustic.model.predict(x, steps=1)
    # convert the index sequence into pinyin text
    _, text = decode_ctc(result, train_data.acoustic_vocab)
    text = ' '.join(text)
    print('Reference pinyin:', ' '.join([train_data.acoustic_vocab[int(idx)] for idx in y]))
    print('Recognized:', text)
    with sess.as_default():
        try:
            _, y = next(language_model_batch)
            text = text.strip('\n').split(' ')
            x = np.array([train_data.pin_vocab.index(pin) for pin in text])
            x = x.reshape(1, -1)
            preds = sess.run(language.preds, {language.x: x})
            got = ''.join(train_data.han_vocab[idx] for idx in preds[0])
            print('Reference characters:', ''.join(train_data.han_vocab[idx] for idx in y[0]))
            print('Recognized:', got)
        except StopIteration:
            break
sess.close()
Example 1


WARNING:tensorflow:From /home/ma-user/anaconda3/envs/TensorFlow-1.13.1/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py:4303: sparse_to_dense (from tensorflow.python.ops.sparse_ops) is deprecated and will be removed in a future version.

Instructions for updating:

Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.


Reference pinyin: lv4 shi4 yang2 chun1 yan1 jing3 da4 kuai4 wen2 zhang1 de di3 se4 si4 yue4 de lin2 luan2 geng4 shi4 lv4 de2 xian1 huo2 xiu4 mei4 shi1 yi4 ang4 ran2

Recognized: lv4 shi4 yang2 chun1 yan1 jing3 da4 kuai4 wen2 zhang1 de di3 se4 si4 yue4 de lin2 luan2 geng4 shi4 lv4 de2 xian1 huo2 xiu4 mei4 shi1 yi4 ang4 ran2

Reference characters: 绿是阳春烟景大块文章的底色四月的林峦更是绿得鲜活秀媚诗意盎然

Recognized: 绿是阳春烟景大块文章的底色四月的林峦更是绿得鲜活秀媚诗意盎然



Example 2

Reference pinyin: ta1 jin3 ping2 yao1 bu4 de li4 liang4 zai4 yong3 dao4 shang4 xia4 fan1 teng2 yong3 dong4 she2 xing2 zhuang4 ru2 hai3 tun2 yi4 zhi2 yi3 yi1 tou2 de you1 shi4 ling3 xian1

Recognized: ta1 jin3 ping2 yao1 bu4 de li4 liang4 zai4 yong3 dao4 shang4 xia4 fan1 teng2 yong3 dong4 she2 xing2 zhuang4 ru2 hai3 tun2 yi4 zhi2 yi3 yi1 tou2 de you1 shi4 ling3 xian1

Reference characters: 他仅凭腰部的力量在泳道上下翻腾蛹动蛇行状如海豚一直以一头的优势领先

Recognized: 他仅凭腰部的力量在蛹道上下翻腾蛹动蛇行状如海豚一直以一头的优势领先



Example 3

Reference pinyin: pao4 yan3 da3 hao3 le zha4 yao4 zen3 me zhuang1 yue4 zheng4 cai2 yao3 le yao3 ya2 shu1 di4 tuo1 qu4 yi1 fu2 guang1 bang3 zi chong1 jin4 le shui3 cuan4 dong4

Recognized: pao4 yan3 da3 hao3 le zha4 yao4 zen3 me zhuang1 yue4 zheng4 cai2 yao3 le yao3 ya2 shu1 di4 tuo1 qu4 yi1 fu2 guang1 bang3 zi chong1 jin4 le shui3 cuan4 dong4

Reference characters: 炮眼打好了炸药怎么装岳正才咬了咬牙倏地脱去衣服光膀子冲进了水窜洞

Recognized: 炮眼打好了炸药怎么装岳正才咬了咬牙倏地脱去衣服光膀子冲进了水窜洞



Example 4

Reference pinyin: ke3 shei2 zhi1 wen2 wan2 hou4 ta1 yi1 zhao4 jing4 zi zhi1 jian4 zuo3 xia4 yan3 jian3 de xian4 you4 cu1 you4 hei1 yu3 you4 ce4 ming2 xian3 bu2 dui4 cheng1

Recognized: ke3 shei2 zhi1 wen2 wan2 hou4 ta1 yi1 zhao4 jing4 zi zhi1 jian4 zuo3 xia4 yan3 jian3 de xian4 you4 cu1 you4 hei1 yu3 you4 ce4 ming2 xian3 bu2 dui4 cheng1

Reference characters: 可谁知纹完后她一照镜子只见左下眼睑的线又粗又黑与右侧明显不对称

Recognized: 可谁只纹完后她一照镜子只见左下眼睑的线右粗右黑与右侧明显不对称



Example 5

Reference pinyin: yi1 jin4 men2 wo3 bei4 jing1 dai1 le zhe4 hu4 ming2 jiao4 pang2 ji2 de lao3 nong2 shi4 kang4 mei3 yuan2 chao2 fu4 shang1 hui2 xiang1 de lao3 bing1 qi1 zi3 chang2 nian2 you3 bing4 jia1 tu2 si4 bi4 yi1 pin2 ru2 xi3

Recognized: yi1 jin4 men2 wo3 bei4 jing1 dai1 le zhe4 hu4 ming2 jiao4 pang2 ji2 de lao3 nong2 shi4 kang4 mei3 yuan2 chao2 fu4 shang1 hui2 xiang1 de lao3 bing1 qi1 zi3 chang2 nian2 you3 bing4 jia1 tu2 si4 bi4 yi1 pin2 ru2 xi3

Reference characters: 一进门我被惊呆了这户名叫庞吉的老农是抗美援朝负伤回乡的老兵妻子长年有病家徒四壁一贫如洗

Recognized: 一进门我被惊呆了这户名叫庞吉的老农是抗美援朝负伤回乡的老兵妻子长年有病家徒四壁一贫如洗



Example 6

Reference pinyin: zou3 chu1 cun1 zi lao3 yuan3 lao3 yuan3 wo3 hai2 hui2 tou2 zhang1 wang4 na4 ge4 an1 ning2 tian2 jing4 de xiao3 yuan4 na4 ge4 shi3 wo3 zhong1 shen1 nan2 wang4 de xiao3 yuan4

Recognized: zou3 chu1 cun1 zi lao3 yuan3 lao3 yuan3 wo3 hai2 hui2 tou2 zhang1 wang4 na4 ge4 an1 ning2 tian2 jing4 de xiao3 yuan4 na4 ge4 shi3 wo3 zhong1 shen1 nan2 wang4 de xiao3 yuan4

Reference characters: 走出村子老远老远我还回头张望那个安宁恬静的小院那个使我终身难忘的小院

Recognized: 走出村子老远老远我还回头张望那个安宁恬静的小院那个使我终身难望的小院



Example 7

Reference pinyin: er4 yue4 si4 ri4 zhu4 jin4 xin1 xi1 men2 wai4 luo2 jia1 nian3 wang2 jia1 gang1 zhu1 zi4 qing1 wen2 xun4 te4 di4 cong2 dong1 men2 wai4 gan3 lai2 qing4 he4

Recognized: er4 yue4 si4 ri4 zhu4 jin4 xin1 xi1 men2 wai4 luo2 jia1 nian3 wang2 jia1 gang1 zhu1 zi4 qing1 wen2 xun4 te4 di4 cong2 dong1 men2 wai4 gan3 lai2 qing4 he4

Reference characters: 二月四日住进新西门外罗家碾王家冈朱自清闻讯特地从东门外赶来庆贺

Recognized: 二月四日住进新西门外罗家碾王家冈朱自清闻讯特地从东门外赶来庆贺



Example 8

Reference pinyin: dan1 wei4 bu2 shi4 wo3 lao3 die1 kai1 de ping2 shen2 me yao4 yi1 ci4 er4 ci4 zhao4 gu4 wo3 wo3 bu4 neng2 ba3 zi4 ji3 de bao1 fu2 wang3 xue2 xiao4 shuai3

Recognized: dan1 wei4 bu2 shi4 wo3 lao3 die1 kai1 de ping2 shen2 me yao4 yi1 ci4 er4 ci4 zhao4 gu4 wo3 wo3 bu4 neng2 ba3 zi4 ji3 de bao1 fu2 wang3 xue2 xiao4 shuai3

Reference characters: 单位不是我老爹开的凭什么要一次二次照顾我我不能把自己的包袱往学校甩

Recognized: 单位不是我老爹开的凭什么要一次二次照顾我我不能把自己的包袱往学校甩



Example 9

Reference pinyin: dou1 yong4 cao3 mao4 huo4 ge1 bo zhou3 hu4 zhe wan3 lie4 lie4 qie ju1 chuan1 guo4 lan4 ni2 tang2 ban1 de yuan4 ba4 pao3 hui2 zi4 ji3 de su4 she4 qu4 le

Recognized: dou1 yong4 cao3 mao4 huo4 ge1 bo zhou3 hu4 zhe wan3 lie4 lie4 qie ju1 chuan1 guo4 lan4 ni2 tang2 ban1 de yuan4 ba4 pao3 hui2 zi4 ji3 de su4 she4 qu4 le

Reference characters: 都用草帽或胳膊肘护着碗趔趔趄趄穿过烂泥塘般的院坝跑回自己的宿舍去了

Recognized: 都用草帽或胳膊肘护着碗趔趔趄趄穿过烂泥塘般的院坝跑回自己的宿舍去了



Example 10

Reference pinyin: xiang1 gang3 yan3 yi4 quan1 huan1 ying2 mao2 a1 min3 jia1 meng2 wu2 xian4 tai2 yu3 hua2 xing1 yi1 xie1 zhong4 da4 de yan3 chang4 huo2 dong4 dou1 yao1 qing3 ta1 chu1 chang3 you3 ji3 ci4 hai2 te4 yi4 an1 pai2 ya1 zhou4 yan3 chu1

Recognized: xiang1 gang3 yan3 yi4 quan1 huan1 ying2 mao2 a1 min3 jia1 meng2 wu2 xian4 tai2 yu3 hua2 xing1 yi1 xie1 zhong4 da4 de yan3 chang4 huo2 dong4 dou1 yao1 qing3 ta1 chu1 chang3 you3 ji3 ci4 hai2 te4 yi4 an1 pai2 ya1 zhou4 yan3 chu1

Reference characters: 香港演艺圈欢迎毛阿敏加盟无线台与华星一些重大的演唱活动都邀请她出场有几次还特意安排压轴演出

Recognized: 香港演艺圈欢迎毛阿敏加盟无线台与华星一些重大的演唱活动都邀请她出场有几次还特艺安排压轴演出
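Several of the samples above (e.g. the 2nd, 4th, 6th, and 10th) contain a few character substitutions. The original tutorial does not measure them, but the standard way to quantify such errors is the character error rate (CER), i.e. edit distance divided by reference length — a small sketch:

```python
def cer(ref, hyp):
    """Character error rate: Levenshtein distance (substitutions +
    deletions + insertions) divided by the reference length.
    Uses a single-row dynamic-programming table."""
    m, n = len(ref), len(hyp)
    dp = list(range(n + 1))
    for i in range(1, m + 1):
        prev, dp[0] = dp[0], i
        for j in range(1, n + 1):
            cur = dp[j]
            dp[j] = min(dp[j] + 1,                          # deletion
                        dp[j - 1] + 1,                      # insertion
                        prev + (ref[i - 1] != hyp[j - 1]))  # substitution
            prev = cur
    return dp[n] / max(m, 1)

# a perfect hypothesis has CER 0; one substitution in three chars gives 1/3
perfect = cer('绿是阳春', '绿是阳春')
one_sub = cer('与右侧', '右右侧')
```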

With this, a simple speech recognition system is complete.

[Copyright notice] This article is original content by a Huawei Cloud community user. When reprinting, you must credit the source (Huawei Cloud community), the article link, and the author; otherwise the author and the community reserve the right to pursue liability. If you find suspected plagiarism in this community, please report it by email with supporting evidence; once verified, the infringing content will be removed immediately. Report email: cloudbbs@huaweicloud.com