理解IR:从MNIST推理看计算图生成

举报
黄生 发表于 2025/10/01 19:51:02 2025/10/01
【摘要】 让我们以经典的MNIST手写数字识别模型为例,一起探索如何生成和查看IR文件。如果您想要了解完整的模型训练流程,建议参考官方文档中的详细教程(教程 (2.6.0) > 快速入门(https://www.mindspore.cn/tutorials/zh-CN/r2.6.0/beginner/quick_start.html)。在这里,我们假设您已经完成了模型训练,并成功生成了最终的模型文件。...

让我们以经典的MNIST手写数字识别模型为例,一起探索如何生成和查看IR文件。如果您想要了解完整的模型训练流程,建议参考官方文档中的详细教程(教程 (2.6.0) > 快速入门(https://www.mindspore.cn/tutorials/zh-CN/r2.6.0/beginner/quick_start.html)。在这里,我们假设您已经完成了模型训练,并成功生成了最终的模型文件。有了这个训练好的模型文件,我们就可以进入推理阶段了。

import mindspore #导入主框架,提供张量操作等基础功能
from mindspore import nn #神经网络模块:包含各种网络层(Conv2d, Linear)、激活函数、损失函数等
from mindspore.dataset import vision, transforms #数据预处理:vision:图像特定变换(Resize, Normalize, RandomCrop) transforms:通用数据变换
from mindspore.dataset import MnistDataset #数据集加载:提供 MNIST 手写数字数据集的直接接口

'''
import os
mindspore.set_context(mode=mindspore.GRAPH_MODE)
os.environ['MS_DEV_SAVE_GRAPHS'] = '3'
os.environ['MS_DEV_SAVE_GRAPHS_PATH'] = './ir'
'''

# Define model
class Network(nn.Cell):
    def __init__(self): #定义网络层结构,不需要输入数据
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x): #self 指向实例本身,x 是输入数据,进行实际计算。前向传播方法,每次推理/训练时调用
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

test_dataset = MnistDataset('MNIST_Data/test')

def datapipe(dataset, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)), #output = (input - mean) / std 
        #减去均值 0.1307 → 数据中心化 除以标准差 0.3081 → 数据标准化 最终数据分布:均值为0,标准差为1
        #Normalize 函数的 mean 和 std 参数需要接收每个通道的均值和标准差,即使只有一个通道值也要用元组形式。
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)

    dataset = dataset.map(image_transforms, 'image')
    dataset = dataset.map(label_transform, 'label')
    dataset = dataset.batch(batch_size)
    return dataset

# Map vision transforms and batch dataset
test_dataset = datapipe(test_dataset, 64)

# Instantiate a random initialized model
model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint("model.ckpt")
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

model.set_train(False)
for data, label in test_dataset:
    pred = model(data)
    predicted = pred.argmax(1)
    print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')
    break

推理过程就像是把一道精心调试好的配方付诸实践——我们不需要再调整配料比例,只需要按照既定的步骤操作就能得到预测结果。以上代码同时会将模型转换为中间表示,也就是我们所说的IR文件。以数字下划线开头的 IR 文件记录了前端编译过程中计算图在不同阶段的中间表达形式。后端优化过程中也会生成一系列 IR 文件,例如以 hwopt 开头的文件,它们更贴近底层硬件优化,通常对框架开发者更为重要。对于大多数非框架开发人员,只需关注名为 graph_build_[图序号]_[IR文件序号].ir 的文件,它代表了前后端全部优化完成后的最终计算图。

由于 IR 文件数量较多,为便于查看,我们可以过滤掉以 hwopt_opt_ 开头的文件。例如可以使用如下命令进行筛选,并按照文件名的数字后缀排序,从而清晰展示编译过程中关键图结构的变化轨迹。

find ./ -name '*.ir' ! -name 'hwopt_*' ! -name 'opt_*' | awk -F "_" '{print $NF " ---> " $0}' | sort -n


编译流程从 bootstrap 阶段开始,紧接着进入 type_inference 阶段,这一环节承担了类型推导与符号解析的任务。系统会递归地遍历入口函数,解析其中对其他函数或对象的引用,并推断出所有节点的数据类型与张量形状。这一过程不仅为后续优化打下类型基础,还能在早期捕获诸如语法不支持或符号未定义等常见错误,为开发者提供及时的反馈,避免问题向后累积。

随后是 optimize 阶段,这一阶段主要进行硬件无关的优化,包括自动微分和自动并行等关键功能的展开。该阶段内部还可进一步划分为多个子阶段,每个子阶段结束后都会生成一份以 opt_pass_[序号] 为前缀的 IR 文件。对于大多数应用开发者而言,无需深入细节,了解其整体作用即可。

validate 阶段,系统会对编译生成的计算图做最终校验。如果图中仍包含仅供内部使用的临时算子,编译过程将在此报错并终止,确保输出图的纯净与可用性。

接下来,task_emit 阶段负责将优化后的计算图传递给后端处理模块。可以将其视为“交图仪式”,前端任务到此基本完成。

最后是 execute 阶段,该阶段标志着前端编译流程的结束,并启动图的执行过程。这一阶段所保存的 IR 图,即为前端编译的最终产物。


我们来具体看上面提供的这几行代码:

  1. PrimFunc_Flatten:拓扑变换的抽象
    PrimFunc_Flatten(%para1_x) 将形状为 (64, 1, 28, 28) 的输入张量(可以理解为64张1x28x28的图片)变换为 (64, 784)。这里的“Flatten”操作是一个高度抽象的原语。它并不关心具体是什么数据,它只声明一个意图:“把后三个维度展平”。在编译器的后端,这个原语可能会被实现为一个无操作的“视图”(View),仅仅改变张量的步长和形状描述,而不实际移动数据;也可能在特定硬件上,被转换为一次实际的内存重排操作。这体现了编译器将高级操作与底层实现解耦的思想。
  2. Load:内存系统的窥视
    Load(%para2_dense...weight, UMonad[U]) 揭示了静态图编译中如何处理状态。%para2_dense...weight 是一个权重参数,在编译时它通常作为一个常量(或可训练的变量)被存储。Load 操作的含义是“从某个存储位置读取数据到计算单元”。更有趣的是 UMonad[U],这是一个“单子”,是函数式编程中的一个概念,在这里用于安全地管理和追踪副作用(Side Effect)。在纯函数式的计算图中,Load 和后续可能会出现的 Store 是典型的副作用操作(改变了或依赖于系统状态,因为这里是推理,所以只有Load)。通过引入Monad,编译器可以在保持图计算纯度的同时,清晰地表达这些必要的IO操作,确保执行的正确顺序,防止读写冲突。
  3. PrimFunc_MatMul:计算核心的宣言
    PrimFunc_MatMul(%0, %1, Bool(0), Bool(1)) 是计算的核心——矩阵乘法。它取展平后的输入 %0 和加载的权重 %1。后面的两个布尔值参数通常用于控制一些高级选项,比如第一个可能是 transpose_a,第二个是 transpose_b。这里 Bool(0), Bool(1) 很可能意味着“不转置第一个矩阵,但转置第二个矩阵”。这正好符合数学:(64, 784) 乘以 (512, 784) 的转置,即 (64, 784) * (784, 512),得到 (64, 512) 的输出。这个原语是性能优化的重中之重,在不同的硬件平台(CPU、GPU、NPU)上,它会被 lowering( lowering 是编译器的术语,指将高级IR转换为更低级、更具体的IR或指令的过程) 成截然不同的指令序列:可能是调用高度优化的cuBLAS/cuDNN库,可能是展开成一系列针对特定计算单元的SIMD指令,也可能是映射到脉动阵列上的数据流。

这几行IR是一个精心设计的抽象层,它既清晰地表达了深度学习模型的计算意图,又为后端五花八门的硬件靶点留下了充分的优化空间。

最后我们看一下静态图的图像展示,以 22_execute_0694.png 为例:

【声明】本内容来自华为云开发者社区博主,不代表华为云及华为云开发者社区的观点和立场。转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息,否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。