DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练、预测
【摘要】 DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练、预测
导读 计算图在神经网络算法中的作用。计算图的节点是由局部计算构成的。局部计算构成全局计算。计算图的正向传播进行一般的计算。通过计算图的反向传播,可以计算各个节点的导数。
目录
输出结果
设计思路
核心代码
...
DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练、预测
导读
计算图在神经网络算法中的作用。计算图的节点是由局部计算构成的。局部计算构成全局计算。计算图的正向传播进行一般的计算。通过计算图的反向传播,可以计算各个节点的导数。
目录
输出结果
设计思路
核心代码
-
-
-
-
class TwoLayerNet:
-
-
def __init__(self, input_size, hidden_size, output_size, weight_init_std = 0.01):
-
-
self.params = {}
-
self.params['W1'] = weight_init_std * np.random.randn(input_size, hidden_size)
-
self.params['b1'] = np.zeros(hidden_size)
-
self.params['W2'] = weight_init_std * np.random.randn(hidden_size, output_size)
-
self.params['b2'] = np.zeros(output_size)
-
-
self.layers = OrderedDict()
-
self.layers['Affine1'] = Affine(self.params['W1'], self.params['b1'])
-
self.layers['Relu1'] = Relu()
-
self.layers['Affine2'] = Affine(self.params['W2'], self.params['b2'])
-
-
self.lastLayer = SoftmaxWithLoss()
-
-
def predict(self, x):
-
for layer in self.layers.values():
-
x = layer.forward(x)
-
-
return x
-
-
# x:输入数据, t:监督数据
-
def loss(self, x, t):
-
y = self.predict(x)
-
return self.lastLayer.forward(y, t)
-
-
def accuracy(self, x, t):
-
y = self.predict(x)
-
y = np.argmax(y, axis=1)
-
if t.ndim != 1 : t = np.argmax(t, axis=1)
-
-
accuracy = np.sum(y == t) / float(x.shape[0])
-
return accuracy
-
-
-
def gradient(self, x, t):
-
self.loss(x, t)
-
-
-
dout = 1
-
dout = self.lastLayer.backward(dout)
-
-
layers = list(self.layers.values())
-
layers.reverse()
-
for layer in layers:
-
dout = layer.backward(dout)
-
-
grads = {}
-
grads['W1'], grads['b1'] = self.layers['Affine1'].dW, self.layers['Affine1'].db
-
grads['W2'], grads['b2'] = self.layers['Affine2'].dW, self.layers['Affine2'].db
-
-
return grads
相关文章
DL之DNN:自定义2层神经网络TwoLayerNet模型(封装为层级结构)利用MNIST数据集进行训练、预测
文章来源: yunyaniu.blog.csdn.net,作者:一个处女座的程序猿,版权归原作者所有,如需转载,请联系作者。
原文链接:yunyaniu.blog.csdn.net/article/details/88959569
【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)