MindIE-LLM ATB模型推理全流程解析

举报
AI布道Mr_jin 发表于 2025/06/26 14:23:05 2025/06/26
【摘要】 最近,有很多小伙伴问我,如果他们想自己基于MindIE镜像中的文件适配新模型,可以怎么做?为了实现这个目标,首先需要了解MindIE-LLM模型在推理过程中的代码调用流程,然后根据新模型的算法进行适配。 背景知识MindIE-LLM组件采用ATB算子构建模型。ATB全称Ascend transformer boost,是一款高效、可靠的加速库,基于华为Ascend AI处理器,专门为Tran...

最近,有很多小伙伴问我,如果他们想自己基于MindIE镜像中的文件适配新模型,可以怎么做?

为了实现这个目标,首先需要了解MindIE-LLM模型在推理过程中的代码调用流程,然后根据新模型的算法进行适配。

背景知识

MindIE-LLM组件采用ATB算子构建模型。ATB全称Ascend transformer boost,是一款高效、可靠的加速库,基于华为Ascend AI处理器,专门为Transformer模型的训练和推理而设计。开发者可以使用ATB算子组图,实现大模型的整图高性能推理,详情可以参考官网链接:https://www.hiascend.com/document/detail/zh/CANNCommunityEdition/82RC1alpha002/acce/ascendtb/ascendtb_0001.html

代码入口

本文以llama模型为例,从入口脚本run_pa.py开始,分析模型路由、模型实例化(权重导入)和图构建推理的过程。

MindIE-LLM ATB模型的推理入口文件在官网MindIE镜像的这个位置:/usr/local/Ascend/atb-models/examples/run_pa.py。这个文件的核心代码如下:

pa_runner = PARunner(**input_dict)
print_log(rank, logger.info, f'pa_runner: {pa_runner}')
pa_runner.warm_up()

infer_params = {
    "inputs": infer_inputs,
    "batch_size": args.max_batch_size,
    "max_output_length": args.max_output_length,
    "ignore_eos": args.ignore_eos,
    "is_chat_model": args.is_chat_model
}
generate_texts, token_nums, _ = pa_runner.infer(**infer_params)

pa_runner实例化的过程中包含了模型类的路由、权重导入和计算图构建,接下来我们逐个分析。

模型类路由

这部分的功能是根据用户传入的config参数获取模型类。

上图是模型类路由的代码调用流程。PARunnerinit函数会调用self.model = ModelRunner()进行模型类的获取。ModelRunner定义在model_runner.py文件中。ModelRunnerinit函数调用router_ins = get_model获取模型信息。我们来看一下get_model()函数:

...
router_path = f"atb_llm.models.{model_type}.router_{model_type}"
if model_type == "qwen2_moe" or model_type == "qwen3_moe":
    model_type = model_type.replace('_', '')
if model_type == "qwen2_audio":
    model_type = model_type.replace('_', '')
if model_type == "qwen2_vl":
    model_type = model_type.replace('_', '')
if model_type == "minicpm_qwen2_v2":
    model_type = model_type.replace('_', '')
router = importlib.import_module(router_path)
router_cls = getattr(router, f"{model_type.capitalize()}Router")
router_ins = router_cls(
    model_name_or_path,
    config_dict,
    is_flash_causal_lm,
    load_tokenizer,
    max_position_embeddings,
    revision,
    tokenizer_path,
    trust_remote_code,
    enable_atb_torch,
    enable_edge,
    enable_refactor,
    llm_config)
return router_ins

从上面代码的第1行可以看到,这个函数根据config文件中的model_type找到了llama模型路由的位置atb_llm\models\llama\router_llama.py,以及router_cls=LlamaRouter()

然后回到ModelRunnerinit函数中,运行了self.model_cls = router_ins.model_cls来获得模型类。LlamaRoutermodel_cls()函数定义在它的基类BaseRouter里面:

def get_model_cls(self):
	...
    model_cls_name = f"{self.model_type_cap}ForCausalLM"
    if self.enable_atb_torch:
        model_cls_name += "ATB"
    if self.is_flash_causal_lm:
        model_cls_name = "Flash" + model_cls_name
    if self.enable_refactor:
        model_cls_name += "V2"
    return getattr(module, model_cls_name)

可以看到,这段代码根据model_cls_name找到了模型类FlashLlamaForCausalLMATB以及它的文件名flash_causal_llama_atb.py。需要注意的是,此时只是获取了模型类,还没有做实例化。

打开代码仓的同学应该发现了,router_llama.pyflash_causal_llama_atb.py都放在atb_models\atb_llm\models\llama目录下。所以,如果你想重新适配一个模型,那么也需要在atb_models\atb_llm\models目录下创建一个新模型对应的目录,并且实现这些文件。

模型实例化&权重导入

PARunner获取到模型类之后,继续调用self.model.load_weights把权重加载到模型中(同时完成了模型实例化),代码调用流程如下:

load_weights函数的主要逻辑如下,包括模型的实例化和模型下发到device:

self.model = self.model_cls(...)
...
self.model.to(weights.device)

FlashLlamaForCausalLMATB的初始化函数中调用了self.model = LlamaModelATB()构建模型,我们继续看一下LlamaModelATB的初始化函数:

...
is_parallel = config.vocab_size >= LLAMA_EMBEDDING_PARALLEL_THRESHOLD
super().__init__(config, weights, model_prefix, lm_head_prefix, is_parallel, is_fa, backend)

self.layers = nn.ModuleList(
    [LlamaLayer(layer_idx, config, weights, model_prefix, self.is_fa, self.backend, speculate_enable) \
        for layer_idx in range(config.num_hidden_layers)])

linear_info = LmHeadLinearInfo()
linear_info.lm_head_name = lm_head_prefix
self.norm = BaseRMSNorm(f"{model_prefix}.norm", config, weights, linear_info)

self.layers又调用了class LlamaLayer定义每一层的结构,详情如下:

...
# 模型结构
self.self_attn = LlamaAttention(
    config=config, weights=weights, prefix=f"{prefix}.self_attn", norm_prefix=f"{prefix}.input_layernorm", \
    is_fa=self.is_fa, backend=backend, speculate_enable=self.speculate_enable)

self.mlp = BaseMLP(
    prefix=f"{prefix}.mlp", config=config, weights=weights,
    norm_prefix=f"{prefix}.post_attention_layernorm", backend=backend)

self.input_layernorm = BaseRMSNorm(
    f"{prefix}.input_layernorm", config, weights, self.self_attn.linear_info)

self.post_attention_layernorm = BaseRMSNorm(
    f"{prefix}.post_attention_layernorm", config, weights, self.mlp.linear_info)

可以看到,上面把transformer层中的attention、mlp和norm层都进行了定义,如果继续观察每一层的初始化函数,可以发现是调用了pytorch的linear算子接口或者nn.Parameter来加载权重,然后把线性层信息保存到self.linear_info变量,下一步进行图构建会用到这个变量。

计算图构建

ModelRunner.load_weights完成权重加载后,继续调用self.model.init_graph()进行ATB算子的调用和计算图构建。self.model对应的是FlashLlamaForCausalLMATB,其init_graph函数继承自基类FlashForCausalLMATB

def init_graph(self):
    """Initialze weight, prefill graph and decode graph."""
    # 获取权重键值对
    self.weight = self.get_weights()
    # 创建atb graph
    self.prefill_graph = AtbGraph(f"{self.name}_prefill_graph")
    self.build_graph(self.prefill_graph, is_prefill=True)
    self.decode_graph = AtbGraph(f"{self.name}_decode_graph")
    self.build_graph(self.decode_graph, is_prefill=False)

可以看到,这个函数初始化了2个AtbGraph,分别对应首token计算图和增量计算图。AtbGraph继承自atb._GraphOperation,是C++的pybind接口,目前这部分代码没有开源。初始化ATB图后,又调用了self.build_graphFlashLlamaForCausalLMATB.build_graph()定义如下:

def build_graph(self, graph, is_prefill):
    # 设置输入输出
    kv_cache_names = []
    for i in range(self.config.num_hidden_layers):
        kv_cache_names.extend([f"layer_{i}_k_cache", f"layer_{i}_v_cache"])
    graph.add_input_output(
        input=list(self.weight.keys()) + kv_cache_names + self.get_in_tensor_names(is_prefill),
        output=self.get_out_tensor_names())

    # 增加图节点
    self.model.build_graph(graph, is_prefill)
    self.build_lm_head(graph, is_prefill)

    # 构图
    graph.execute_as_single = False
    graph.build()

首先准备了输入输出,然后调用self.model.build_graph构图,对应的是LlamaModelATB.build_graph()

def build_graph(self, graph, is_prefill):
    self.build_word_embedding_graph(graph)
    self.build_positional_embedding_graph(graph)
    for layer in self.layers:
        layer.build_graph(graph, is_prefill)
    self.norm.build_graph(graph, is_prefill)

可以看到,代码逻辑是把每一层都build到graph里面去,我们继续打开LlamaLayer.build_graph()

def build_graph(self, graph, is_prefill):
    self.layer_graph = AtbGraph(("prefill" if is_prefill else "decode") + f"_layer_{self.layer_id}_graph")
    self.layer_graph.add_input_output(
        input=self.weight_names + ["k_cache", "v_cache"] + self.get_in_tensor_names(is_prefill),
        output=["layer_out"])
    if self.is_reshape:
        self.layer_graph.add_reshape("hidden_states", "hidden_states", self.reshape_parallel)
    self.input_layernorm.build_graph(self.layer_graph, is_prefill)
    self.self_attn.build_graph(self.layer_graph, is_prefill)
    self.post_attention_layernorm.build_graph(self.layer_graph, is_prefill)
    self.mlp.build_graph(self.layer_graph, is_prefill)
    self.layer_graph.build()

    graph.operations.append(self.layer_graph)
    graph.add_operation(self.layer_graph, self.weight_names + \
    [f"layer_{self.layer_id}_k_cache", f"layer_{self.layer_id}_v_cache"] + self.get_in_tensor_names(
        is_prefill), ["hidden_states"])

这段代码首先建立了一个子图self.layer_graph,然后把norm层、attention层和mlp层都进行build。我们以self_attn.build_graph为例继续打开:

def build_graph(self, graph, is_prefill):
    atten_res_add = atb._BaseOperation(op_type="Elewise", 				op_param=json.dumps({'elewiseType': 'ELEWISE_ADD'}),
                                       op_name='atten_res_add')
    setattr(graph, 'atten_res_add', atten_res_add)

    self.build_qkv_graph(graph)
    self.build_rope_graph(graph)
    self.build_attention_graph(graph, is_prefill)
    self.build_dense_graph(graph, is_prefill)

    graph.add_operation(graph.atten_res_add, ['hidden_states', 'dense_out'], ['hidden_states'])

这里面又包含了qkv的计算、attention计算和输出映射层的计算,我们看一下build_attention_graph是如何调用ATB算子的:

def build_attention_graph(self, graph, is_prefill):
	...
    pa_attention_builder = CommonOpBuilderManager.get_builder(attention_param)
    graph = pa_attention_builder.build(graph, attention_tensor_map)

可以看到,这里通过CommonOpBuilderManager.get_builder获得了pa_attention算子的builder。CommonOpBuilderManager是定义在common_op_builder_manager.py里面的类,它的功能是把transformer模型通用的算子进行管理,方便用户构建模型的时候调用。它的代码实现如下::

class CommonOpBuilderManager:
    _common_op_builders = []

    @classmethod
    def register(cls, common_op_builder_class):
        cls._common_op_builders.append(common_op_builder_class())

    @classmethod
    def get_builder(cls, param: dict) -> BaseCommonOpBuilder | None:
        for common_op_builder in cls._common_op_builders:
            if common_op_builder.is_match(param):
                return common_op_builder
        print_log(ENV.rank, logger.debug, f"CommonOpBuilder not found for param: {param}")
        raise RuntimeError(f"CommonOpBuilder not found for param: {param}")

注意到,它的get_builder函数可以根据传入的param返回对应的算子builder。而且字典变量_common_op_builders里面的值是通过调用register进行更新的。大家可能有疑问,这个register函数是在哪里被调用的呢?实际上是在atb-models/atb_llm/common_op_builders下面的每类算子的__init__.py中执行的,比如atb_models\atb_llm\common_op_builders\attention

from atb_llm.common_op_builders.common_op_builder_manager import CommonOpBuilderManager
from atb_llm.common_op_builders.attention.atb_decoder_paged_attention_common_op_builder import \
    ATBDecoderPagedAttentionCommonOpBuilder
from atb_llm.common_op_builders.attention.atb_encoder_paged_attention_common_op_builder import \
    ATBEncoderPagedAttentionCommonOpBuilder
from atb_llm.common_op_builders.attention.atb_flash_attention_common_op_builder import \
    ATBFlashAttentionCommonOpBuilder

CommonOpBuilderManager.register(ATBDecoderPagedAttentionCommonOpBuilder)
CommonOpBuilderManager.register(ATBEncoderPagedAttentionCommonOpBuilder)
CommonOpBuilderManager.register(ATBFlashAttentionCommonOpBuilder)

对于prefill_graph,我们结合build_attention_graph()函数中的attention_param

attention_param = {
    "op_name": "attention",
    "category": CommonOpBuilderType.ATTENTION,
    "is_prefill": is_prefill,
    "attn_type": AttnType.FLASH_ATTENTION if self.is_fa else AttnType.PAGED_ATTENTION,
    "head_size": self.head_size,
    "atb_reshape_and_cache_param": {},
    "operation_backend": OperationBackend.ATB,
    "atb_attention_param": self._get_atb_attention_param(is_prefill)
}

以及ATBEncoderPagedAttentionCommonOpBuilderis_match()函数,可知获取的op_builder是ATBEncoderPagedAttentionCommonOpBuilder类,它的build()函数逻辑如下:

def build(self, graph: atb._GraphOperation, tensor_map: dict) -> atb._GraphOperation:
    ...
    # self attention
    attention_op = atb._BaseOperation(
        op_type="SelfAttention",
        op_param=json.dumps(self.param.atb_attention_param),
        op_name=f"{self.param.op_name}_SelfAttention"
    )
    graph.operations.append(attention_op)
    ...
    return graph

可以看到,这里通过atb._BaseOperation接口调用了atb算子。

其他算子的调用逻辑也同理,大家可以自己查看一遍。

总结

这篇文章主要分析了ATB模型推理的代码调用栈,同时给出了新模型适配涉及的代码目录。ATB模型的适配代码目录在/usr/local/Ascend/atb-models/atb_llm/models,以llama模型为例,/usr/local/Ascend/atb-models/atb_llm/models/llama下面包含模型路由脚本router_llama.py以及模型类的定义脚本flash_causal_llama_atb.pymodeling_llama_atb.py。如果需要适配新的模型,需要在/models下面创建新的目录并实现上述脚本内容。

MindIE-LLM提供了构建transformer模型的通用算子,统一放在/usr/local/Ascend/atb-models/atb_llm/common_op_builders目录下面,每个算子都通过_libatb_torch._BaseOperation的方式调用ATB算子。

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。