Vits语音模型昇腾适配

举报
HuaweiCloudDeveloper 发表于 2026/01/07 11:18:19 2026/01/07
【摘要】 1      下载代码git clone https://github.com/jaywalnut310/vits.gitcd vits注意:有时候可能下载不了代码。这边给镜像加速,https://github.akams.cn/, 2      环境安装推荐使用conda环境安装依赖,这边参考:https://3ms.huawei.com/km/groups/3957721/blogs/d...

1      下载代码

git clone https://github.com/jaywalnut310/vits.git

cd vits

注意:有时可能无法直接从 GitHub 下载代码,可以使用镜像加速:https://github.akams.cn/

 

2      环境安装

推荐先使用 conda 创建独立环境,再安装依赖,可参考:https://3ms.huawei.com/km/groups/3957721/blogs/details/21544735

仓库中 requirements.txt 指定的依赖版本较旧,下面给出完整的依赖安装命令:

# Install one package per line from the Tsinghua PyPI mirror.
pip install torch==2.5.1 -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip install pyyaml -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip install torch-npu==2.5.1.post1 -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip install torchaudio==2.5.1 -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip install numpy==1.26.4 -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
# soundfile is used by the inference script to write the output WAV.
pip install soundfile -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip install scipy -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip install librosa -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip install unidecode -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip install phonemizer -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip install cython -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple
pip install psutil -i https://mirrors.tuna.tsinghua.edu.cn/pypi/web/simple

# phonemizer needs a system espeak backend (the VITS README suggests this):
# apt-get install espeak

# The VITS README requires building the Cython monotonic_align extension
# (models.py imports it), so build it once from the repo root before inference:
cd monotonic_align
python setup.py build_ext --inplace
cd ..

 

3      权重下载

下载链接:https://drive.google.com/drive/folders/1ksarh-cJf3F5eKJjLVWY0X1j1qsQqiS2?usp=sharing   

目前该模型也提供训练脚本,可以根据需要自行训练。

 

4      推理脚本构建

仓库中已提供推理参考文件(inference.ipynb)。

参考该文件,编写如下推理脚本:

import os
import json
import time
import torch
import torch_npu
import soundfile as sf

from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

# Project-local modules from the VITS repository (run from the repo root).
import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence


# =============================
# Configuration
# =============================
config_path = "./configs/ljs_base.json"
model_path = "/home/yaowenjie/mobvoi/vits/pretrained_ljs.pth"
output_wav = "test.wav"
# Make sure the directory that will hold the output WAV exists. The original
# referenced an undefined `profiling_dir` here, which raised NameError.
# os.path.dirname() is "" for a bare filename, so fall back to the CWD.
os.makedirs(os.path.dirname(output_wav) or ".", exist_ok=True)

# NPU runtime settings (tune for your hardware).
# NOTE(review): see the torch_npu docs for the exact semantics of these flags.
torch_npu.npu.config.allow_internal_format = False
torch_npu.npu.set_compile_mode(jit_compile=False)

# Device selection
device = torch.device("npu:0")


# =============================
# Load model hyper-parameters
# =============================
hps = utils.get_hparams_from_file(config_path)
print("加载 config 成功")


# =============================
# Build the generator
# =============================
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model
)
net_g = net_g.eval().to(device)


# =============================
# Load model weights
# =============================
def load_checkpoint(checkpoint_path, model):
    """Load generator weights from a checkpoint file into ``model``.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` that holds a
            dict with a ``"model"`` state dict and an optional ``"iteration"``.
        model: Module that receives ``checkpoint["model"]`` via
            ``load_state_dict``.

    Returns:
        The same ``model`` object, with the weights loaded.

    Raises:
        KeyError: If the checkpoint has no ``"model"`` entry.
    """
    print(f"🔄 加载 checkpoint: {checkpoint_path}")
    # Deserialize onto the CPU; load_state_dict then copies the tensors into
    # the model's existing parameters on whatever device they live on.
    # NOTE(review): torch.load unpickles data -- only load trusted checkpoints.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint_dict.get("iteration", 0)
    saved_state_dict = checkpoint_dict["model"]
    model.load_state_dict(saved_state_dict)
    print(f"加载完成,迭代次数 {iteration}")
    return model


# Load the pretrained weights into the generator built above.
net_g = load_checkpoint(checkpoint_path=model_path, model=net_g)


# =============================
# Text preprocessing
# =============================
def get_text(text, hps, cleaners=None):
    """Convert raw text into the symbol-ID tensor consumed by the model.

    Args:
        text: Input sentence.
        hps: Hyper-parameters loaded from the model's JSON config.
        cleaners: Optional list of text-cleaner names. Defaults to
            ``hps.data.text_cleaners`` so preprocessing matches the config
            the checkpoint was trained with.

    Returns:
        A 1-D ``torch.LongTensor`` of symbol IDs.
    """
    if cleaners is None:
        # The original hard-coded ['english_cleaners'], which can disagree with
        # the config (the stock configs/ljs_base.json lists english_cleaners2 --
        # confirm in your config) and yields different phoneme IDs.
        cleaners = hps.data.text_cleaners
    text_norm = text_to_sequence(text, cleaners)
    if hps.data.add_blank:
        # Interleave blank ID 0 between symbols (see commons.intersperse),
        # presumably mirroring the training data pipeline.
        text_norm = commons.intersperse(text_norm, 0)
    return torch.LongTensor(text_norm)


# =============================
# Inference entry point
# =============================
if __name__ == "__main__":
    # Input text
    text_input = "VITS is Awesome!"
    print(f"📝 输入文本: {text_input}")

    # Encode the text into a [1, T] batch plus its length tensor
    stn_tst = get_text(text_input, hps).to(device=device)
    x_tst = stn_tst.unsqueeze(0)
    x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device=device)

    print(f"📐 模型输入维度: {x_tst.shape}, 长度: {x_tst_lengths.item()}")

    # NPU work runs asynchronously. Synchronize before starting and before
    # stopping the clock so the interval covers the whole inference, not just
    # kernel launches. perf_counter is monotonic and meant for intervals.
    # Note: a first call also includes one-off NPU initialization cost.
    torch.npu.synchronize()
    time_start = time.perf_counter()
    with torch.no_grad():
        audio_tensor = net_g.infer(
            x_tst, x_tst_lengths,
            noise_scale=.667,
            noise_scale_w=0.8,
            length_scale=1
        )[0]
    torch.npu.synchronize()
    time_end = time.perf_counter()

    # Take element [0, 0] of infer()'s first output as the 1-D waveform.
    # NOTE(review): assumes a [batch, channel, samples] layout -- confirm.
    audio = audio_tensor[0, 0].data.cpu().float().numpy()
    duration = audio.shape[0] / hps.data.sampling_rate
    print(f"推理耗时: {time_end - time_start:.3f} s")
    print(f"🎵 合成音频长度: {duration:.2f} s")

    # Save the waveform as 24-bit PCM
    sf.write(output_wav, audio, samplerate=hps.data.sampling_rate, subtype='PCM_24')
    print(f"💾 音频已保存至: {output_wav}")

由于昇腾 NPU 首次推理存在较大的初始化开销,第一次调用会明显偏慢。因此下面修改了代码,在正式计时前先进行预热(warm-up)。


import os
import json
import time
import torch
import torch_npu
import soundfile as sf

from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

# Project-local modules from the VITS repository (run from the repo root).
import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence


# =============================
# Configuration
# =============================
config_path = "./configs/ljs_base.json"
model_path = "/home/yaowenjie/mobvoi/vits/pretrained_ljs.pth"
output_wav = "test.wav"

# Make sure the output directory exists. os.path.dirname() is "" for a bare
# filename, in which case the current directory is used.
os.makedirs(os.path.dirname(output_wav) or ".", exist_ok=True)

# NPU runtime settings (tune for your hardware).
# NOTE(review): see the torch_npu docs for the exact semantics of these flags.
torch_npu.npu.config.allow_internal_format = False
torch_npu.npu.set_compile_mode(jit_compile=False)

# Device selection
device = torch.device("npu:0")


# =============================
# Load model hyper-parameters
# =============================
hps = utils.get_hparams_from_file(config_path)
print("加载 config 成功")


# =============================
# Build the generator
# =============================
net_g = SynthesizerTrn(
    len(symbols),
    hps.data.filter_length // 2 + 1,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model
)
net_g = net_g.eval().to(device)


# =============================
# Load model weights
# =============================
def load_checkpoint(checkpoint_path, model):
    """Load generator weights from a checkpoint file into ``model``.

    Args:
        checkpoint_path: Path to a file readable by ``torch.load`` that holds a
            dict with a ``"model"`` state dict and an optional ``"iteration"``.
        model: Module that receives ``checkpoint["model"]`` via
            ``load_state_dict``.

    Returns:
        The same ``model`` object, with the weights loaded.

    Raises:
        KeyError: If the checkpoint has no ``"model"`` entry.
    """
    print(f"🔄 加载 checkpoint: {checkpoint_path}")
    # Deserialize onto the CPU; load_state_dict then copies the tensors into
    # the model's existing parameters on whatever device they live on.
    # NOTE(review): torch.load unpickles data -- only load trusted checkpoints.
    checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
    iteration = checkpoint_dict.get("iteration", 0)
    saved_state_dict = checkpoint_dict["model"]
    model.load_state_dict(saved_state_dict)
    print(f"加载完成,迭代次数 {iteration}")
    return model


# Load the pretrained weights into the generator built above.
net_g = load_checkpoint(checkpoint_path=model_path, model=net_g)


# =============================
# Text preprocessing
# =============================
def get_text(text, hps, cleaners=None):
    """Convert raw text into the symbol-ID tensor consumed by the model.

    Args:
        text: Input sentence.
        hps: Hyper-parameters loaded from the model's JSON config.
        cleaners: Optional list of text-cleaner names. Defaults to
            ``hps.data.text_cleaners`` so preprocessing matches the config
            the checkpoint was trained with.

    Returns:
        A 1-D ``torch.LongTensor`` of symbol IDs.
    """
    if cleaners is None:
        # The original hard-coded ['english_cleaners'], which can disagree with
        # the config (the stock configs/ljs_base.json lists english_cleaners2 --
        # confirm in your config) and yields different phoneme IDs.
        cleaners = hps.data.text_cleaners
    text_norm = text_to_sequence(text, cleaners)
    if hps.data.add_blank:
        # Interleave blank ID 0 between symbols (see commons.intersperse),
        # presumably mirroring the training data pipeline.
        text_norm = commons.intersperse(text_norm, 0)
    return torch.LongTensor(text_norm)


# =============================
# Inference entry point (with warm-up)
# =============================
if __name__ == "__main__":
    print(f"🔧 使用设备: {device}")

    # ======== Warm-up ========
    # The first inference calls on the NPU are much slower than later ones, so
    # run a few short dummy inputs before timing the real request.
    warmup_iters = 2
    warmup_texts = ("Hello", "Test")
    print(f"🔥 开始预热... ({warmup_iters} 次)")
    net_g.eval()
    with torch.no_grad():
        for i in range(warmup_iters):
            # Alternate between short dummy texts to keep warm-up cheap.
            dummy_text = warmup_texts[i % len(warmup_texts)]
            stn_tst_warm = get_text(dummy_text, hps).to(device)
            x_tst_warm = stn_tst_warm.unsqueeze(0)
            x_tst_lengths_warm = torch.LongTensor([stn_tst_warm.size(0)]).to(device)

            # One full inference pass; the output is discarded.
            _ = net_g.infer(
                x_tst_warm, x_tst_lengths_warm,
                noise_scale=0.667,
                noise_scale_w=0.8,
                length_scale=1.0
            )[0]
            torch.npu.synchronize()  # wait until this warm-up pass has finished

    print("预热完成")

    # ======== Timed inference ========
    text_input = "VITS is Awesome!"
    print(f"📝 输入文本: {text_input}")

    # Encode the text into a [1, T] batch plus its length tensor
    stn_tst = get_text(text_input, hps).to(device)
    x_tst = stn_tst.unsqueeze(0)  # [1, T]
    x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)

    print(f"📐 模型输入维度: {x_tst.shape}, 长度: {x_tst_lengths.item()}")

    # NPU work is asynchronous: drain the queue before starting the clock and
    # again before stopping it, so the interval covers exactly this inference.
    # perf_counter is monotonic and intended for measuring intervals.
    torch.npu.synchronize()
    time_start = time.perf_counter()
    with torch.no_grad():
        audio_tensor = net_g.infer(
            x_tst, x_tst_lengths,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=1.0
        )[0]
    torch.npu.synchronize()
    time_end = time.perf_counter()

    # Take element [0, 0] of infer()'s first output as the 1-D waveform.
    # NOTE(review): assumes a [batch, channel, samples] layout -- confirm.
    audio = audio_tensor[0, 0].data.cpu().float().numpy()
    duration = audio.shape[0] / hps.data.sampling_rate
    infer_time = time_end - time_start

    print(f"推理耗时: {infer_time:.3f} s")
    print(f"🎵 合成音频长度: {duration:.2f} s")
    # Real-time factor = compute time / audio duration (< 1 is faster than
    # real time). Guard against an empty waveform to avoid ZeroDivisionError.
    rtf = infer_time / duration if duration > 0 else float("inf")
    print(f"📊 实时因子 (RTF): {rtf:.3f}")

    # Save the waveform as 24-bit PCM
    sf.write(output_wav, audio, samplerate=hps.data.sampling_rate, subtype='PCM_24')
    print(f"💾 音频已保存至: {output_wav}")




 



【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。