MindSpore的模型训练

举报
剑指南天 发表于 2026/06/15 19:34:49 2026/06/15
【摘要】 本文通过 MindSpore 的API完整实现一个深度学习模型,涉及到模型训练的各个环节,对学习 MindSpore 的模型训练很有作用。

1.概述

模型训练一般包括:①数据的整理;②构建数据集;③构建模型;④项目配置;⑤模型训练;⑥模型推理;⑦模型验证。

2. 项目配置

关于文件保存的配置一般包括原始数据存储位置配置,处理之后数据的保存位置,模型保存位置,日志文件位置和文件名等配置。另一方面主要是模型的配置。项目配置需要根据任务,便宜设置。

from pathlib import Path

import mindspore
from download import download
from mindspore import nn
from mindspore.dataset import MnistDataset
from mindspore.dataset import vision, transforms

# 项目配置
ROOT_DIR = Path(__file__).parent.parent
# 原始文件保存位置
RAW_DATA_DIR = ROOT_DIR / 'data' / 'raw'
RAW_DATA_FOR_TRAIN = 'MNIST_Data/train'
RAW_DATA_FOR_TEST = 'MNIST_Data/test'
# 原始文件下载链接
RAW_DATA_URL = "https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/" \
               "notebook/datasets/MNIST_Data.zip"
# 模型文件保存位置
MODEL_DIR = ROOT_DIR / 'models'
BEST_MODEL_NAME = 'best_model.ckpt'
# 模型训练超参数
EPOCHS = 30
BATCH_SIZE = 64
LEARNING_RATE = 1e-2

3. 数据的整理

模型的数据来源是多种多样的,数据质量也是千差万别。比如网络上爬取的数据需要经过下面的预处理,才可以作为训练数据。

模型对数据的要求也是多种多样的,比如大模型微调数据应该由人工按照模板整理数据。

本次实验数据是公共数据,是比较干净的数据,不需要清洗,直接下载即可。

# 收集数据
download(RAW_DATA_URL, str(RAW_DATA_DIR), kind="zip", replace=True)

4. 构建数据集

实验数据是公共数据集,已经符合数据集要求,只需要加载,转换,无需构建。

def datapipe(path, batch_size):
    image_transforms = [
        vision.Rescale(1.0 / 255.0, 0),
        vision.Normalize(mean=(0.1307,), std=(0.3081,)),
        vision.HWC2CHW()
    ]
    label_transform = transforms.TypeCast(mindspore.int32)
    # 从文件加载数据
    dataset = MnistDataset(path)
    # 转换训练数据
    dataset = dataset.map(image_transforms, 'image')
    # 转化训练数据标签
    dataset = dataset.map(label_transform, 'label')
    # 数据分批
    dataset = dataset.batch(batch_size)
    return dataset

# 训练数据
train_dataset = datapipe(str(RAW_DATA_DIR/RAW_DATA_FOR_TRAIN), batch_size=BATCH_SIZE)
# 测试数据
test_dataset = datapipe(str(RAW_DATA_DIR/RAW_DATA_FOR_TEST), batch_size=BATCH_SIZE)

5. 网络构建

神经网络模型由神经网络层和Tensor操作构成,mindspore.nn 提供了常见神经网络层的实现。可以参考 https://www.mindspore.cn/docs/zh-CN/stable/note/api_mapping/pytorch_api_mapping.html,将 pytorch 代码转化为MindSpore 代码。

class Network(nn.Cell):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()
        self.dense_relu_sequential = nn.SequentialCell(
            nn.Dense(28*28, 512),
            nn.ReLU(),
            nn.Dense(512, 512),
            nn.ReLU(),
            nn.Dense(512, 10)
        )

    def construct(self, x):
        x = self.flatten(x)
        logits = self.dense_relu_sequential(x)
        return logits

6. 模型训练、测试和保存表现最优模型

# 初始化模型
model = Network()
# 声明损失函数
loss_fn = nn.CrossEntropyLoss()
# 声明优化器
optimizer = nn.SGD(model.trainable_params(), learning_rate=LEARNING_RATE)

# 获得微分函数,用于计算梯度
def forward_fn(data, label):
    logits = model(data)
    loss = loss_fn(logits, label)
    return loss, logits


# has_aux参数设置为True时,可以自动实现添加stop_gradient的功能,满足返回辅助数据的同时不影响梯度计算的效果
grad_fn = mindspore.value_and_grad(forward_fn, None, optimizer.parameters, has_aux=True)

# Define function of one-step training
def train_step(data, label):
    (loss, _), grads = grad_fn(data, label)
    optimizer(grads)
    return loss
# 训练一个epoch
def train_loop(model, dataset):
    size = dataset.get_dataset_size()
    model.set_train()
    for batch, (data, label) in enumerate(dataset.create_tuple_iterator()):
        loss = train_step(data, label)
        if batch % 100 == 0:
            loss, current = loss.asnumpy(), batch
            print(f"loss: {loss:>7f}  [{current:>3d}/{size:>3d}]")

# 推理
def model_infer(model,data):
    pred = model(data)
    return pred.argmax(1)

# 测试
def test_loop(model, dataset):
    model.set_train(False)
    total, correct = 0, 0
    for data, label in dataset.create_tuple_iterator():
        pred_label = model_infer(model,data)
        correct += (pred_label == label).asnumpy().sum()
        total += len(data)
    correct /= total
    print(f"Test: \n Accuracy: {(100 * correct):>0.1f}% \n")
    return correct

correct_rate = 0.0
for t in range(EPOCHS):
    print(f"Epoch {t + 1}\n-------------------------------")
    train_loop(model, train_dataset)
    correct = test_loop(model, test_dataset)
    if correct > correct_rate:
        mindspore.save_checkpoint(model, str(MODEL_DIR / BEST_MODEL_NAME))
        print(f"Epoch {t + 1}: 保存模型成功!\n")
        correct_rate = correct

7. 模型的加载和推理

# Instantiate a random initialized model
model = Network()
# Load checkpoint and load parameter to model
param_dict = mindspore.load_checkpoint(str(MODEL_DIR / BEST_MODEL_NAME))
param_not_load, _ = mindspore.load_param_into_net(model, param_dict)
print(param_not_load)

model.set_train(False)
for data, label in test_dataset:
    predicted = model_infer(model,data)
    print(f'Predicted: "{predicted[:10]}", Actual: "{label[:10]}"')
    break

8. 总结

本文通过 MindSpore 的API完整实现一个深度学习模型,涉及到模型训练的各个环节,对学习 MindSpore 的模型训练很有作用。

【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。