[热门活动]
MindSpore版本更新体验——MindSpore 1.2.0-rc1
一、平台环境
CPU:lntel(R) Core(TM) i7-10875H CPU@ 2.30GHz
内存:24GB
操作系统:Win 10 20H2
二、安装Mindspore 1.2.0-rc1
安装验证
查看MindSpore版本
python
import mindspore
print(mindspore.__version__)
三、基于Mindspore 1.2.0-rc1在本地进行模型转换
简单的导出一个LeNet网络的MindIR格式模型
MindSpore官网为我们提供了LeNet的Checkpoint文件,提供了不同版本的:https://download.mindspore.cn/model_zoo/official/cv/lenet/
*Checkpoint • 采用了Protocol Buffers格式,存储了网络中所有的参数值。一般用于训练任务中断后恢复训练,或训练后的微调(Fine Tune)任务。在这里我选择了CPU
1.定义网络
import mindspore.nn as nnfrom mindspore.common.initializer import Normalclass LeNet5(nn.Cell):
"""
Lenet network structure
"""
#define the operator required
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
#use the preceding operators to construct networks
def construct(self, x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
2.导出模型为MindIR格式模型
import time
import mindspore.nn as nn
import numpy as np
from datetime import datetime
from mindspore.common.initializer import Normal
from mindspore import Tensor, export, load_checkpoint, load_param_into_net
lenet = LeNet5()
# 返回模型的参数字典
param_dict = load_checkpoint("./lenet.ckpt")
# 加载参数到网络
load_param_into_net(lenet, param_dict)
input = np.random.uniform(0.0, 1.0, size=[32, 1, 32, 32]).astype(np.float32)
# 以指定的名称和格式导出文件
export(lenet, Tensor(input), file_name='lenet.mindir', file_format='MINDIR',)
t = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print(" ")
print("============== Model conversion succeeded ==============")
print("The current Mindspore version is:",mindspore.__version__)
print(t)
个人邮箱:chunjcsx20@vip.qq.com