Hands-On Speech Signal Preprocessing with PyTorch

yd_284014651, published 2025/06/28 11:05:41


Deep-learning model training is a supervised-learning process: given training inputs and their corresponding labels, a loss function together with backpropagation drives the model to converge. For images, labeling a classification or detection dataset means annotating each object's class and position; for speech tasks, labeling works quite differently.

This exercise uses two common speech datasets, 'yesno' and 'aishell', to walk through the data preprocessing pipelines for speech classification and speech recognition respectively.

The outline of the exercise (reordered here to match the sections below):

  • Downloading the datasets
    • Downloading the yesno dataset
    • Downloading the aishell dataset
  • Preprocessing for speech classification
    • Splitting the audio signal
    • Aligning audio lengths
  • Preprocessing for speech recognition
    • CMVN (Cepstral Mean and Variance Normalization), computed over the training data
    • Dictionary generation
    • Reading audio files
    • Tokenizing text sequences
    • Filtering audio signals
    • Resampling
    • Speed perturbation
    • Feature extraction
    • Spectrogram augmentation (spec_aug)
    • Shuffling and sorting (shuffle and sort)
    • Batching
    • Padding audio signals

Downloading the datasets

Downloading the yesno dataset

The yesno dataset ships with torchaudio, so torchaudio must be installed first, e.g. with 'pip install torchaudio'.

The dataset contains 60 audio files named like 0_0_1_1_0_1_0_1.wav. Each recording contains eight words (yes or no) matching the file name, where 1 stands for yes and 0 for no.

import torchaudio
from torchaudio.datasets import YESNO

# The YESNO dataset is downloaded into the data_save_path folder
dataset = YESNO(root='data_save_path', download=True)

Downloading the aishell dataset
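
The aishell (AISHELL-1) corpus is not bundled with torchaudio; it is published on OpenSLR as resource 33 (https://www.openslr.org/33/), from which the audio archive (data_aishell.tgz) and its transcripts can be downloaded and extracted. The recognition sections below assume the extracted wav files and the transcript file are available locally.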

Preprocessing for speech classification

Audio classification is a common application of deep learning in audio processing: it assigns audio signals to categories such as human voice, background music, or noise. It underpins many scenarios, for example speech recognition, music recommendation, and noise suppression.

In deep learning, audio classification is typically implemented with neural networks, computational models inspired by how biological neurons work. Convolutional neural networks (CNNs), recurrent neural networks (RNNs), or other architectures can process and classify the audio signal. Preprocessing the speech signal makes it better match the network's input requirements.

This section uses the yesno dataset from 'torchaudio.datasets' to illustrate the preprocessing for speech classification, which consists of two steps: splitting the audio and aligning the audio lengths.

Splitting the yesno audio

Each downloaded file contains several words, each of which is yes or no. We can therefore split every file into single-word clips so that one file corresponds to one word, turning audio classification into a binary classification problem. The splitting procedure is as follows:

For some files this procedure does not split the audio into exactly 8 pieces, so an output file's index cannot simply be computed as file number × 8. Instead, the code below keeps a running counter (out_idx) of how many clips earlier files have already produced.

import os
import torchaudio

os.makedirs('./data_path/waves_yesno_seg', exist_ok=True)

# Running index of the next output clip: tracks how many clips earlier
# files produced instead of assuming every file splits into exactly 8
out_idx = 0

for filename in os.listdir('./data_path'):
    if not filename.endswith('.wav'):
        continue
    # Load the amplitude-over-time sequence; the per-word labels (0 / 1)
    # are encoded in the file name
    waveform, sample_rate = torchaudio.load(f'./data_path/{filename}')

    # Threshold: mean plus one standard deviation of the signal
    threshold = waveform[0].mean() + waveform[0].std()

    # Scan the samples and record each word's start/end sample, using this
    # rule: a sample above the threshold opens a segment; once the signal
    # stays below the threshold for 1000 consecutive samples, the segment
    # closes
    start_times = []
    end_times = []
    segment_start = None
    count = 0
    for i in range(len(waveform[0])):
        if abs(waveform[0][i].item()) > threshold:
            count = 0
            if segment_start is None:
                segment_start = i
        elif segment_start is not None:
            if count < 1000:
                count += 1
            else:
                start_times.append(segment_start)
                end_times.append(i)
                segment_start = None
                count = 0

    # Cut the waveform into per-word segments
    segments = [waveform[0][s:e] for s, e in zip(start_times, end_times)]

    # The file name (e.g. 0_0_1_1_0_1_0_1.wav) encodes one label per word;
    # if segmentation yields extra pieces, pad the labels with indices
    labels = filename.split('.')[0].split('_')
    for i in range(len(labels), len(segments)):
        labels.append(i)

    # Save each clip, named "<index>_<label>.wav"
    for i, segment in enumerate(segments):
        torchaudio.save(
            f'./data_path/waves_yesno_seg/{out_idx}_{labels[i]}.wav',
            segment.view(1, -1), sample_rate=sample_rate)
        out_idx += 1


After splitting, each recording yields 8 clips, giving a collection of 480 clips that each contain a single yes or no.

Aligning the yesno audio lengths

Because the size of the input layer is set by hand and cannot change during training, all inputs must have equal length. The split clips inevitably differ in length, so we use PyTorch's pad_sequence to align them; after alignment every file has the same, maximum size (about 14 KB here).

import os
import torchaudio
from torch.nn.utils.rnn import pad_sequence

os.makedirs('./data_path/waves_yesno_pad', exist_ok=True)

waveforms = []
labels = []
for filename in os.listdir('./data_path/waves_yesno_seg/'):
    waveform, sr = torchaudio.load(f'./data_path/waves_yesno_seg/{filename}')
    waveforms.append(waveform[0].view(-1))
    labels.append(filename.split('.')[0].split('_')[-1])

# Zero-pad every clip to the length of the longest one (13283 samples in
# this run); the target length can be chosen as needed
waveforms = pad_sequence(waveforms, batch_first=True)

for i, waveform in enumerate(waveforms):
    torchaudio.save(f'./data_path/waves_yesno_pad/{i}_{labels[i]}.wav',
                    waveform.view(1, -1), 8000)

Preprocessing for speech recognition

We pick 4 samples from the aishell dataset to walk through the whole preprocessing pipeline for speech recognition. The input for one audio file has two parts: the speech content listed in wav.scp format, and a text file with the corresponding transcript. Because training consumes data in batches, the two files must provide a mapping between audio and text.

The wav.scp file stores each utterance ID together with the absolute path of that utterance on the system. Its purpose is to tie the utterance ID to the audio's absolute path so that the acoustic feature extraction and data augmentation stages can locate and process the audio.

The text file stores each utterance ID together with its transcript.
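
For illustration, one line of each file might look like this (the path is hypothetical; note the comma separating the two columns, which the parsing code below relies on):

wav.scp:
BAC009S0002W0122,/path/to/data_aishell/wav/train/S0002/BAC009S0002W0122.wav

text:
BAC009S0002W0122,而对楼市成交抑制作用最大的限购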

CMVN (Cepstral Mean and Variance Normalization), computed over the training data

Since training feeds the model data in batches, we define an AudioDataset class to parse the contents of the scp file. The line 'arr = line.strip().split(",")' splits on ',', matching the comma separator in the files shown above; the parsed entries are kept in 'self.items' for use by '__getitem__', which overrides the base Dataset method of the same name. 'torchaudio.set_audio_backend("sox_io")' selects the backend that implements audio file I/O: "sox_io" is the default on Linux/macOS, while "soundfile" should be used on Windows.

import codecs

import torch
import torchaudio
import torchaudio.compliance.kaldi as kaldi
from torch.utils.data import Dataset, DataLoader

torchaudio.set_audio_backend("sox_io")

class AudioDataset(Dataset):
    def __init__(self, data_file):
        self.items = []
        with codecs.open(data_file, 'r', encoding='utf-8') as f:
            for line in f:
                arr = line.strip().split(",")
                self.items.append((arr[0], arr[1]))

    def __len__(self):
        return len(self.items)

    def __getitem__(self, idx):
        return self.items[idx]

CMVN is computed over all wavs in the training set. Assuming each wav is represented with 80 mel-filterbank dimensions, it accumulates per-dimension statistics across all batches: the sums needed for the mean and variance of each of the 80 dimensions, plus the total frame count (mean_stat, var_stat and the frame number).

This is implemented through a CollateFunc class. 'resample_rate' can be reset to give the data a new sampling rate, and 'waveform = waveform * (1 << 15)' scales the waveform up (into the 16-bit integer range expected by the Kaldi-style feature extraction). The parameter 'feat_dim' is the dimensionality of the extracted features and 'resample_rate' the resampling rate used during extraction. The function's main job is to accumulate the mean and variance statistics of the extracted features over all training batches.

class CollateFunc(object):

    def __init__(self, feat_dim, resample_rate):
        self.feat_dim = feat_dim
        self.resample_rate = resample_rate

    def __call__(self, batch):
        mean_stat = torch.zeros(self.feat_dim)
        var_stat = torch.zeros(self.feat_dim)
        number = 0
        for item in batch:
            value = item[1].strip().split(",")
            assert len(value) == 3 or len(value) == 1
            wav_path = value[0]
            # The sample rate is read from the wav header, so any sample
            # rate configured externally is ignored here
            sample_rate = torchaudio.backend.sox_io_backend.info(wav_path).sample_rate
            resample_rate = sample_rate
            # len(value) == 3 means segmented wav.scp,
            # len(value) == 1 means original wav.scp
            if len(value) == 3:
                start_frame = int(float(value[1]) * sample_rate)
                end_frame = int(float(value[2]) * sample_rate)
                waveform, sample_rate = torchaudio.backend.sox_io_backend.load(
                    filepath=wav_path,
                    num_frames=end_frame - start_frame,
                    frame_offset=start_frame)
            else:
                waveform, sample_rate = torchaudio.load(item[1])

            waveform = waveform * (1 << 15)
            if self.resample_rate != 0 and self.resample_rate != sample_rate:
                resample_rate = self.resample_rate
                waveform = torchaudio.transforms.Resample(
                    orig_freq=sample_rate, new_freq=resample_rate)(waveform)

            mat = kaldi.fbank(waveform,
                              num_mel_bins=self.feat_dim,
                              dither=0.0,
                              energy_floor=0.0,
                              sample_frequency=resample_rate)
            mean_stat += torch.sum(mat, axis=0)
            var_stat += torch.sum(torch.square(mat), axis=0)
            number += mat.shape[0]
        return number, mean_stat, var_stat

Set 'feat_dim=80' and 'resample_rate=16000', then run 'CollateFunc' and 'AudioDataset' over the 4 utterances above; the variance, mean and frame-count statistics are accumulated from data_loader as follows:

feat_dim = 80
resample_rate = 16000
batch_size = 2  # batch size for the statistics pass; choose as appropriate
collate_func = CollateFunc(feat_dim, resample_rate)
dataset = AudioDataset(args.in_scp)
data_loader = DataLoader(dataset,
                         batch_size=batch_size,
                         shuffle=True,
                         sampler=None,
                         num_workers=args.num_workers,
                         collate_fn=collate_func)

all_number = 0  # total number of frames
all_mean_stat = torch.zeros(feat_dim)
all_var_stat = torch.zeros(feat_dim)
wav_number = 0  # total number of wav files
for i, batch in enumerate(data_loader):
    number, mean_stat, var_stat = batch
    all_mean_stat += mean_stat
    all_var_stat += var_stat
    all_number += number
    wav_number += batch_size
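
From these accumulated sums, the global per-dimension statistics follow directly; a minimal sketch of the final step (the normalization itself later subtracts the mean and multiplies by the inverse standard deviation):

mean = all_mean_stat / all_number
variance = all_var_stat / all_number - mean * mean  # E[x^2] - (E[x])^2
istd = 1.0 / torch.sqrt(variance.clamp(min=1e-20))  # guard against zero variance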

Dictionary generation

The labels for the training utterances are passages of text. During training, the loss cannot be computed by comparing text directly; instead, a mapping turns this complex problem into a numerical one by assigning each character a unique number. Using this dictionary, every utterance's transcript is converted into numbers, which we call tokens.

Running the conversion script over the four transcripts above produces the following symbol_table dictionary:

[产: 2, 亿: 3, 付: 4, 令: 5, 估: 6, 低: 7, 体: 8, 例: 9, 值: 10, 偏: 11, 元: 12, 公: 13, 出: 14, 分: 15, 前: 16, 区: 17, 去: 18, 喜: 19, 国: 20, 地: 21, 场: 22, 外: 23, 州: 24, 已: 25, 市: 26, 广: 27, 总: 28, 息: 29, 我: 30, 房: 31, 报: 32, 据: 33, 日: 34, 昨: 35, 望: 36, 款: 37, 比: 38, 消: 39, 的: 40, 目: 41, 积: 42, 紧: 43, 者: 44, 购: 45, 贷: 46, 过: 47, 道: 48, 部: 49, 金: 50, 降: 51, 首: 52]

The core purpose of dictionary generation is label construction: every distinct character in the training set is assigned a unique numeric ID. A sketch of the generation step follows.
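
A minimal sketch of how such a table can be built (the file name 'text', the comma separator, and reserving IDs 0 and 1 for '<blank>' and '<unk>' are assumptions here; the IDs in the listing above starting at 2 are consistent with two reserved symbols):

symbol_table = {'<blank>': 0, '<unk>': 1}  # assumed reserved symbols
chars = set()
with open('text', encoding='utf-8') as f:
    for line in f:
        # each line: "<utterance id>,<transcript>"
        _, transcript = line.strip().split(',', 1)
        chars.update(transcript.replace(' ', ''))
# sort the unique characters so IDs are assigned deterministically
for i, ch in enumerate(sorted(chars), start=len(symbol_table)):
    symbol_table[ch] = i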

Reading audio files

For the aishell-1 dataset the raw format suffices (raw is intended for small-scale datasets). 'parse_raw' expects each data line in the following format:

{"key": "BAC009S0002W0122", "wav": "./data_aishell/wav/train/S0002/BAC009S0002W0122.wav", "txt": "而对楼市成交抑制作用最大的限购"}

Each line carries the utterance ID, the audio path and the corresponding transcript, matching key, wav and txt in the code.

The function uses 'start' and 'end' to set the first and last sample frames when only a segment of the whole recording should be loaded. The parsed result is stored in 'example'.

def parse_raw(data):
    """ Parse key/wav/txt from json line

        Args:
            data: Iterable[str], str is a json line has key/wav/txt

        Returns:
            Iterable[{key, wav, txt, sample_rate}]
    """
    for sample in data:
        assert 'src' in sample
        json_line = sample['src']
        obj = json.loads(json_line)
        assert 'key' in obj
        assert 'wav' in obj
        assert 'txt' in obj
        key = obj['key']
        wav_file = obj['wav']
        txt = obj['txt']
        try:
            if 'start' in obj:
                assert 'end' in obj
                sample_rate = torchaudio.backend.sox_io_backend.info(
                    wav_file).sample_rate
                start_frame = int(obj['start'] * sample_rate)
                end_frame = int(obj['end'] * sample_rate)
                waveform, _ = torchaudio.backend.sox_io_backend.load(
                    filepath=wav_file,
                    num_frames=end_frame - start_frame,
                    frame_offset=start_frame)
            else:
                waveform, sample_rate = torchaudio.load(wav_file)
            example = dict(key=key,
                           txt=txt,
                           wav=waveform,
                           sample_rate=sample_rate)
            yield example
        except Exception as ex:
            logging.warning('Failed to read {}'.format(wav_file))

Tokenizing text sequences

tokenize splits the label text in data into tokens and uses the symbol_table mapping to compute the label vector, i.e. labels the model can recognize. For example, every character of "而对楼市成交抑制作用最大的限购" becomes one token, and each token is converted to its numeric ID via symbol_table and stored in label.

def tokenize(data,
             symbol_table,
             bpe_model=None,
             non_lang_syms=None,
             split_with_space=False):
    """ Decode text to chars or BPE
        Inplace operation

        Args:
            data: Iterable[{key, wav, txt, sample_rate}]

        Returns:
            Iterable[{key, wav, txt, tokens, label, sample_rate}]
    """
    if non_lang_syms is not None:
        non_lang_syms_pattern = re.compile(r"(\[[^\[\]]+\]|<[^<>]+>|{[^{}]+})")
    else:
        non_lang_syms = {}
        non_lang_syms_pattern = None

    if bpe_model is not None:
        import sentencepiece as spm
        sp = spm.SentencePieceProcessor()
        sp.load(bpe_model)
    else:
        sp = None

    for sample in data:
        assert 'txt' in sample
        txt = sample['txt'].strip()
        if non_lang_syms_pattern is not None:
            parts = non_lang_syms_pattern.split(txt.upper())
            parts = [w for w in parts if len(w.strip()) > 0]
        else:
            parts = [txt]

        label = []
        tokens = []
        for part in parts:
            if part in non_lang_syms:
                tokens.append(part)
            else:
                if bpe_model is not None:
                    tokens.extend(__tokenize_by_bpe_model(sp, part))
                else:
                    if split_with_space:
                        part = part.split(" ")
                    for ch in part:
                        if ch == ' ':
                            ch = "▁"
                        tokens.append(ch)

        for ch in tokens:
            if ch in symbol_table:
                label.append(symbol_table[ch])
            elif '<unk>' in symbol_table:
                label.append(symbol_table['<unk>'])

        sample['tokens'] = tokens
        sample['label'] = label
        yield sample
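
A quick hypothetical run, using characters that appear in the symbol_table above:

sample = {'key': 'utt1', 'wav': None, 'txt': '广州房市', 'sample_rate': 16000}
out = next(tokenize([sample], symbol_table))
# out['tokens'] -> ['广', '州', '房', '市']
# out['label']  -> [27, 24, 31, 26], the IDs assigned in symbol_table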

Filtering audio signals

Each utterance is sampled into some number of frames. filter bounds the frame count, the token count, and the token-to-frame ratio between configured minima and maxima in order to drop noisy or degenerate speech signals.

def filter(data,
           max_length=10240,
           min_length=10,
           token_max_length=200,
           token_min_length=1,
           min_output_input_ratio=0.0005,
           max_output_input_ratio=1):
    """ Filter sample according to feature and label length
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            max_length: drop utterances longer than max_length (unit: 10ms frames)
            min_length: drop utterances shorter than min_length (unit: 10ms frames)
            token_max_length: drop utterances whose token count exceeds
                token_max_length, especially when char units are used for
                English modeling
            token_min_length: drop utterances whose token count is below
                token_min_length
            min_output_input_ratio: minimal ratio of
                token_length / feats_length (unit: 10ms frames)
            max_output_input_ratio: maximum ratio of
                token_length / feats_length (unit: 10ms frames)

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'label' in sample
        # sample['wav'] is torch.Tensor, we have 100 frames every second
        num_frames = sample['wav'].size(1) / sample['sample_rate'] * 100
        if num_frames < min_length:
            continue
        if num_frames > max_length:
            continue
        if len(sample['label']) < token_min_length:
            continue
        if len(sample['label']) > token_max_length:
            continue
        if num_frames != 0:
            if len(sample['label']) / num_frames < min_output_input_ratio:
                continue
            if len(sample['label']) / num_frames > max_output_input_ratio:
                continue
        yield sample


Resampling

Audio is usually recorded at some default sampling rate; resample sets a new sampling rate and resamples the speech signal accordingly.

def resample(data, resample_rate=16000):
    """ Resample data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            resample_rate: target resample rate

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        if sample_rate != resample_rate:
            sample['sample_rate'] = resample_rate
            sample['wav'] = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=resample_rate)(waveform)
        yield sample

Speed perturbation

Note that speech rate and articulation rate differ: if an utterance is fluent, without frequent or long silent pauses, fast articulation means a fast overall rate; with many or long pauses, the overall rate can be slow even when articulation is fast. 'speed_perturb' adjusts an utterance's speed, i.e. speeds it up or slows it down. Three factors are supported here: '0.9' (0.9x speed), '1.0' (unchanged) and '1.1' (1.1x speed).

def speed_perturb(data, speeds=None):
    """ Apply speed perturb to the data.
        Inplace operation.

        Args:
            data: Iterable[{key, wav, label, sample_rate}]
            speeds(List[float]): optional speed

        Returns:
            Iterable[{key, wav, label, sample_rate}]
    """
    if speeds is None:
        speeds = [0.9, 1.0, 1.1]
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        speed = random.choice(speeds)
        if speed != 1.0:
            wav, _ = torchaudio.sox_effects.apply_effects_tensor(
                waveform, sample_rate,
                [['speed', str(speed)], ['rate', str(sample_rate)]])
            sample['wav'] = wav

        yield sample

Feature extraction

Speech signals also need feature extraction before being fed into the network for training; two common extraction methods are provided here, producing fbank or mfcc features.

def compute_fbank(data,
                  num_mel_bins=23,
                  frame_length=25,
                  frame_shift=10,
                  dither=0.0):
    """ Extract fbank

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'key' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        waveform = waveform * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.fbank(waveform,
                          num_mel_bins=num_mel_bins,
                          frame_length=frame_length,
                          frame_shift=frame_shift,
                          dither=dither,
                          energy_floor=0.0,
                          sample_frequency=sample_rate)
        yield dict(key=sample['key'], label=sample['label'], feat=mat)


def compute_mfcc(data,
                 num_mel_bins=23,
                 frame_length=25,
                 frame_shift=10,
                 dither=0.0,
                 num_ceps=40,
                 high_freq=0.0,
                 low_freq=20.0):
    """ Extract mfcc

        Args:
            data: Iterable[{key, wav, label, sample_rate}]

        Returns:
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'sample_rate' in sample
        assert 'wav' in sample
        assert 'key' in sample
        assert 'label' in sample
        sample_rate = sample['sample_rate']
        waveform = sample['wav']
        waveform = waveform * (1 << 15)
        # Only keep key, feat, label
        mat = kaldi.mfcc(waveform,
                         num_mel_bins=num_mel_bins,
                         frame_length=frame_length,
                         frame_shift=frame_shift,
                         dither=dither,
                         num_ceps=num_ceps,
                         high_freq=high_freq,
                         low_freq=low_freq,
                         sample_frequency=sample_rate)
        yield dict(key=sample['key'], label=sample['label'], feat=mat)

Spectrogram augmentation (spec_aug)

Spec augmentation is a form of data augmentation: it randomly zeroes out a few small contiguous ranges along the time and frequency axes, forcing the model to predict from the whole sequence, much like injecting noise on the fly. 'SpecAugmentation' from 'torchlibrosa.augmentation' can serve as a replacement; a sketch follows the code below.

def spec_aug(data, num_t_mask=2, num_f_mask=2, max_t=50, max_f=10, max_w=80):
    """ Do spec augmentation
        Inplace operation

        Args:
            data: Iterable[{key, feat, label}]
            num_t_mask: number of time mask to apply
            num_f_mask: number of freq mask to apply
            max_t: max width of time mask
            max_f: max width of freq mask
            max_w: max width of time warp

        Returns
            Iterable[{key, feat, label}]
    """
    for sample in data:
        assert 'feat' in sample
        x = sample['feat']
        assert isinstance(x, torch.Tensor)
        y = x.clone().detach()
        max_frames = y.size(0)
        max_freq = y.size(1)
        # time mask
        for i in range(num_t_mask):
            start = random.randint(0, max_frames - 1)
            length = random.randint(1, max_t)
            end = min(max_frames, start + length)
            y[start:end, :] = 0
        # freq mask
        for i in range(num_f_mask):
            start = random.randint(0, max_freq - 1)
            length = random.randint(1, max_f)
            end = min(max_freq, start + length)
            y[:, start:end] = 0
        sample['feat'] = y
        yield sample
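
As mentioned above, 'torchlibrosa.augmentation' offers equivalent functionality; a minimal sketch (the parameter values are illustrative, the module expects a (batch, channel, time, freq) tensor, and masking is only applied while the module is in training mode):

import torch
from torchlibrosa.augmentation import SpecAugmentation

spec_augmenter = SpecAugmentation(time_drop_width=50, time_stripes_num=2,
                                  freq_drop_width=10, freq_stripes_num=2)
spec_augmenter.train()  # stripes are only dropped in training mode

feat = torch.randn(1, 1, 300, 80)  # (batch, channel, frames, mel bins)
augmented = spec_augmenter(feat)   # same shape, with random stripes zeroed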

Shuffling and sorting (shuffle and sort)

shuffle collects samples into a buffer and, once it holds shuffle_size samples (10000 by default), reshuffles them into a new random order. sort waits until its buffer holds sort_size samples (500 by default) and then orders them by frame count.

def shuffle(data, shuffle_size=10000):
    """ Local shuffle the data

        Args:
            data: Iterable[{key, feat, label}]
            shuffle_size: buffer size for shuffle

        Returns:
            Iterable[{key, feat, label}]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= shuffle_size:
            random.shuffle(buf)
            for x in buf:
                yield x
            buf = []
    # The sample left over
    random.shuffle(buf)
    for x in buf:
        yield x


def sort(data, sort_size=500):
    """ Sort the data by feature length.
        Sort is used after shuffle and before batch, so we can group
        utts with similar lengths into a batch, and `sort_size` should
        be less than `shuffle_size`

        Args:
            data: Iterable[{key, feat, label}]
            sort_size: buffer size for sort

        Returns:
            Iterable[{key, feat, label}]
    """

    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= sort_size:
            buf.sort(key=lambda x: x['feat'].size(0))
            for x in buf:
                yield x
            buf = []
    # The sample left over
    buf.sort(key=lambda x: x['feat'].size(0))
    for x in buf:
        yield x

Batching

This module serves the training loop: it groups input samples into batches of a given size before they enter the model. Two packing schemes are provided, 'static_batch' and 'dynamic_batch'.

def static_batch(data, batch_size=16):
    """ Static batch the data by `batch_size`

        Args:
            data: Iterable[{key, feat, label}]
            batch_size: batch size

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    for sample in data:
        buf.append(sample)
        if len(buf) >= batch_size:
            yield buf
            buf = []
    if len(buf) > 0:
        yield buf


def dynamic_batch(data, max_frames_in_batch=12000):
    """ Dynamic batch the data until the total frames in batch
        reach `max_frames_in_batch`

        Args:
            data: Iterable[{key, feat, label}]
            max_frames_in_batch: max_frames in one batch

        Returns:
            Iterable[List[{key, feat, label}]]
    """
    buf = []
    longest_frames = 0
    for sample in data:
        assert 'feat' in sample
        assert isinstance(sample['feat'], torch.Tensor)
        new_sample_frames = sample['feat'].size(0)
        longest_frames = max(longest_frames, new_sample_frames)
        frames_after_padding = longest_frames * (len(buf) + 1)
        if frames_after_padding > max_frames_in_batch:
            yield buf
            buf = [sample]
            longest_frames = new_sample_frames
        else:
            buf.append(sample)
    if len(buf) > 0:
        yield buf


def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000):
    """ Wrapper for static/dynamic batch
    """
    if batch_type == 'static':
        return static_batch(data, batch_size)
    elif batch_type == 'dynamic':
        return dynamic_batch(data, max_frames_in_batch)
    else:
        logging.fatal('Unsupported batch type {}'.format(batch_type))

Padding audio signals

Because input utterances vary in speed and length, the inputs must be padded so that the tensors entering the model have consistent sizes; padding pads the feats and labels within a batch to a common length.

def padding(data):
    """ Padding the data into training data

        Args:
            data: Iterable[List[{key, feat, label}]]

        Returns:
            Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
    """
    for sample in data:
        assert isinstance(sample, list)
        feats_length = torch.tensor([x['feat'].size(0) for x in sample],
                                    dtype=torch.int32)
        order = torch.argsort(feats_length, descending=True)
        feats_lengths = torch.tensor(
            [sample[i]['feat'].size(0) for i in order], dtype=torch.int32)
        sorted_feats = [sample[i]['feat'] for i in order]
        sorted_keys = [sample[i]['key'] for i in order]
        sorted_labels = [
            torch.tensor(sample[i]['label'], dtype=torch.int64) for i in order
        ]
        label_lengths = torch.tensor([len(sample[i]['label']) for i in order],
                                     dtype=torch.int32)

        padded_feats = pad_sequence(sorted_feats,
                                    batch_first=True,
                                    padding_value=0)
        padding_labels = pad_sequence(sorted_labels,
                                      batch_first=True,
                                      padding_value=-1)

        yield (sorted_keys, padded_feats, padding_labels, feats_lengths,
               label_lengths)
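
Since every stage above is a generator, the whole preprocessing pipeline chains together naturally. A minimal sketch of wiring the stages end to end (assuming the functions above are in scope, 'raw_lines' holds JSON lines in the format from the reading section, and 'symbol_table' is the dictionary built earlier; the stage order follows this walkthrough):

data = [{'src': line} for line in raw_lines]
data = parse_raw(data)
data = tokenize(data, symbol_table)
data = filter(data)
data = resample(data, resample_rate=16000)
data = speed_perturb(data)
data = compute_fbank(data, num_mel_bins=80)
data = spec_aug(data)
data = shuffle(data)
data = sort(data)
data = batch(data, batch_type='static', batch_size=16)
data = padding(data)

for keys, feats, labels, feats_lengths, label_lengths in data:
    pass  # each item is a padded, length-sorted batch ready for the model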
