WeNet云端推理部署代码解析(上)
WeNet是一款开源端到端ASR工具包,它与ESPnet等开源语音项目相比,最大的优势在于提供了从训练到部署的一整套工具链,使ASR服务的工业落地更加简单。如图1所示,WeNet工具包完全依赖于PyTorch生态:使用TorchScript进行模型开发,使用Torchaudio进行动态特征提取,使用DistributedDataParallel进行分布式训练,使用torch JIT(Just In Time)进行模型导出,使用LibTorch作为生产环境运行时。本系列将对WeNet云端推理部署代码进行解析。
图1:WeNet系统设计[1]
1. 代码结构
WeNet云端推理和部署代码位于wenet/runtime/server/x86路径下,编程语言为C++,其结构如下所示:
其中:
- 语音文件读入与特征提取相关代码位于frontend文件夹下;
- 端到端模型导入、端点检测与语音解码识别相关代码位于decoder文件夹下,WeNet支持CTC prefix beam search和融合了WFST的CTC beam search这两种解码算法,后者的实现大量借鉴了Kaldi,相关代码放在kaldi文件夹下;
- 在服务化方面,WeNet分别实现了基于WebSocket和基于gRPC的两套服务端与客户端,基于WebSocket的实现位于websocket文件夹下,基于gRPC的实现位于grpc文件夹下,两种实现的入口main函数代码都位于bin文件夹下。
- 日志、计时、字符串处理等辅助代码位于utils文件夹下。
WeNet提供了CMakeLists.txt和Dockerfile,使得用户能方便地进行项目编译和镜像构建。
2. 前端:frontend文件夹
1)语音文件读入
WeNet只支持44字节header的wav格式音频数据,wav header定义在WavHeader结构体中,包括音频格式、声道数、采样率等音频元信息。WavReader类用于语音文件读入,调用fopen打开语音文件后,WavReader先读入WavHeader大小的数据(也就是44字节),再根据WavHeader中的元信息确定待读入音频数据的大小,最后调用fread把音频数据读入buffer,并通过static_cast把数据转化为float类型。
struct WavHeader {
char riff[4]; // "riff"
unsigned int size;
char wav[4]; // "WAVE"
char fmt[4]; // "fmt "
unsigned int fmt_size;
uint16_t format;
uint16_t channels;
unsigned int sample_rate;
unsigned int bytes_per_second;
uint16_t block_size;
uint16_t bit;
char data[4]; // "data"
unsigned int data_size;
};
这里存在的一个风险是,如果WavHeader中存放的元信息有误,则会影响到语音数据的正确读入。
2)特征提取
WeNet使用的特征是fbank,通过FeaturePipelineConfig结构体进行特征设置。默认帧长为25ms,帧移为10ms,采样率和fbank维数则由用户输入。
用于特征提取的类是FeaturePipeline。为了同时支持流式与非流式语音识别,FeaturePipeline类中设置了input_finished_属性来标志输入是否结束,并通过set_input_finished()成员函数来对input_finished_属性进行操作。
提取出来的fbank特征放在feature_queue_中,feature_queue_的类型是BlockingQueue<std::vector<float>>。BlockingQueue类是WeNet实现的一个阻塞队列,初始化的时候需要提供队列的容量(capacity),通过Push()函数向队列中增加特征,通过Pop()函数从队列中读取特征:
- 当feature_queue_中的feature数量超过capacity,则Push线程被挂起,等待feature_queue_.Pop()释放出空间。
- 当feature_queue_为空,则Pop线程被挂起,等待feature_queue_.Push()。
线程的挂起和恢复是通过C++标准库中的线程同步原语std::mutex、std::condition_variable等实现。
线程同步还用在AcceptWaveform和ReadOne两个成员函数中,AcceptWaveform把语音数据提取得到的fbank特征放到feature_queue_中,ReadOne成员函数则把特征从feature_queue_中读出,是经典的生产者消费者模式。
3. 解码器:decoder文件夹
1)TorchAsrModel
通过torch::jit::load对存在磁盘上的模型进行反序列化,得到一个ScriptModule对象。
torch::jit::script::Module model = torch::jit::load(model_path);
2)SearchInterface
WeNet推理支持的解码方式都继承自基类SearchInterface,如果要新增解码算法,则需继承SearchInterface类,并提供该类中所有纯虚函数的实现,包括:
// 解码算法的具体实现
virtual void Search(const torch::Tensor& logp) = 0;
// 重置解码过程
virtual void Reset() = 0;
// 结束解码过程
virtual void FinalizeSearch() = 0;
// 解码算法类型,返回一个枚举常量SearchType
virtual SearchType Type() const = 0;
// 返回解码输入
virtual const std::vector<std::vector<int>>& Inputs() const = 0;
// 返回解码输出
virtual const std::vector<std::vector<int>>& Outputs() const = 0;
// 返回解码输出对应的似然值
virtual const std::vector<float>& Likelihood() const = 0;
// 返回解码输出对应的次数
virtual const std::vector<std::vector<int>>& Times() const = 0;
目前WeNet只提供了SearchInterface的两种子类实现,也即两种解码算法,分别定义在CtcPrefixBeamSearch和CtcWfstBeamSearch两个类中。
3)CtcEndpoint
WeNet支持语音端点检测,提供了一种基于规则的实现方式,用户可以通过CtcEndpointConfig结构体和CtcEndpointRule结构体进行规则配置。WeNet默认的规则有三条:
- 检测到了5s的静音,则认为检测到端点;
- 解码出了任意时长的语音后,检测到了1s的静音,则认为检测到端点;
- 解码出了20s的语音,则认为检测到端点。
一旦检测到端点,则结束解码。另外,WeNet把解码得到的空白符(blank)视作静音。
4)TorchAsrDecoder
WeNet提供的解码器定义在TorchAsrDecoder类中。如图3所示,WeNet支持双向解码,即叠加从左往右解码和从右往左解码的结果。在CTC beam search之后,用户还可以选择进行attention重打分。
图2:WeNet解码计算流程[2]
可以通过DecodeOptions结构体进行解码参数配置,包括如下参数:
struct DecodeOptions {
int chunk_size = 16;
int num_left_chunks = -1;
float ctc_weight = 0.0;
float rescoring_weight = 1.0;
float reverse_weight = 0.0;
CtcEndpointConfig ctc_endpoint_config;
CtcPrefixBeamSearchOptions ctc_prefix_search_opts;
CtcWfstBeamSearchOptions ctc_wfst_search_opts;
};
其中,ctc_weight表示CTC解码权重,rescoring_weight表示重打分权重,reverse_weight表示从右往左解码权重。最终解码打分的计算方式为:
final_score = rescoring_weight * rescoring_score + ctc_weight * ctc_score;
rescoring_score = left_to_right_score * (1 - reverse_weight) +
right_to_left_score * reverse_weight
TorchAsrDecoder对外提供的解码接口是Decode(),重打分接口是Rescoring()。Decode()返回的是枚举类型DecodeState,包括三个枚举常量:kEndBatch,kEndpoint和kEndFeats,分别表示当前批数据解码结束、检测到端点、所有特征解码结束。
为了支持长语音识别,WeNet还提供了连续解码接口ResetContinuousDecoding(),它与解码器重置接口Reset()的区别在于:连续解码接口会记录全局已经解码的语音帧数,并保留当前feature_pipeline_的状态。
总结
本文主要对WeNet云端推理代码进行探索,介绍了代码结构、前端和解码器部分代码。在《WeNet云端推理部署代码解析(下)》中,笔者将继续解析WeNet云端部署代码。
参考
[1] WeNet: Production First and Production Ready End-to-End Speech Recognition Toolkit
[2] U2++: Unified Two-pass Bidirectional End-to-end Model for Speech Recognition
- 点赞
- 收藏
- 关注作者
评论(0)