MindIE DeepSeek MTP特性定位策略

举报
AI布道Mr_jin 发表于 2025/06/23 15:42:51 2025/06/23
【摘要】 最近MindIE开始支持DeepSeek MTP(multi token prediction)特性了,用于推理加速。但是有些开发者打开MTP开关后,没有发现明显的性能提升。这篇文章提供一种定位策略。原理很简单,就是看一下每次MTP推理后,模型是输出1个token还是多个token。由于MTP的token处理算法是用python实现的,所以可以在镜像的python代码中添加日志,可以在2个地...

最近MindIE开始支持DeepSeek MTP(multi token prediction)特性了,用于推理加速。但是有些开发者打开MTP开关后,没有发现明显的性能提升。这篇文章提供一种定位策略。

原理很简单,就是看一下每次MTP推理后,模型是输出1个token还是多个token。由于MTP的token处理算法是用python实现的,所以可以在镜像的python代码中添加日志,可以在2个地方加日志查看MTP的采信率(也就是verify的成功比例)。

首先可以在MindIE镜像的/usr/local/lib/python3.11/site-packages/mindie_llm/text_generator/plugins/mtp/mtp_plugin.py 路径中找到verify_greedy_one_batch()函数,然后打印相关参数。

    @staticmethod
    def verify_greedy_one_batch(verify_guess_tokens, next_guess_tokens):
        gg = 0
        for eg, guess_tokens in enumerate(verify_guess_tokens):
            correct = next_guess_tokens[eg]
            guess = guess_tokens
            if guess != correct:
                break
            gg += 1

或者在 /usr/local/lib/python3.11/site-packages/mindie_llm/text_generator/plugins/plugin_manager.py里面的init函数和generate_token函数增加如下代码:

from ...utils.log.logging import logger  # 新增
class PluginManager:
    def __init__(...):
        self.all_token_num = 0
        self.all_decode_count = 0
        self.all_prefill_count = 0
        ...
    @timer.track_time_async('generate_token')
    def generate_token(self, input_metadata: InputMetadata):
        ...
        span_end(prof)

        if not input_metadata.is_dummy_batch:
            if not input_metadata.is_prefill:
                for i in range(input_metadata.batch_size):
                    next_tokens = generation_output.token_ids[i]
                    if -1 in next_tokens:
                        first_neg_one_index = np.argmax(next_tokens == -1)
                        next_tokens = next_tokens[:first_neg_one_index]
                    self.all_token_num += len(next_tokens)
                    self.all_decode_count += 1
            logger.error(f"self.all_token_num is {self.all_token_num}, self.all_decode_count is {self.all_decode_count}, self.all_prefill_count is {self.all_prefill_count}. Ratio is {self.all_token_num / self.all_decode_count}")

            if input_metadata.is_prefill:
                for _ in range(input_metadata.batch_size):
                    self.all_prefill_count += 1
                    
        generation_output.trace_ids = trace_ids
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。