神经网络参数优化更新的步骤——tensorflow实现线性回归

举报
Y_K_C 发表于 2021/10/10 20:50:21 2021/10/10
【摘要】 基本思路本文是简单地体验一下神经网络参数更新的流程,因此不涉及激活函数和Drop_out等知识点。首先利用高斯分布随机生成2000个点,这2000个点围绕某条已知的直线,再初始化权重参数w和偏移量b,根据w和b计算出预测值,再与真实值比较计算出损失函数(采用均方误差作为指标),使用梯度下降的优化方法更新参数使得损失函数最小化,最后让整个线性回归模型训练500次即可。 代码及流程本例使用Mo...

基本思路

本文是简单地体验一下神经网络参数更新的流程,因此不涉及激活函数和Drop_out等知识点。
首先利用高斯分布随机生成2000个点,这2000个点围绕某条已知的直线,再初始化权重参数w和偏移量b,根据w和b计算出预测值,再与真实值比较计算出损失函数(采用均方误差作为指标),使用梯度下降的优化方法更新参数使得损失函数最小化,最后让整个线性回归模型训练500次即可。

代码及流程

本例使用ModelArts进行模型部署,所用的框架是tensorflow。

1.进入ModelArts控制台和创建项目

进入ModelArts控制台(控制台选择华北-北京四,北京四有免费资源)后选择开发环境,点击进入Notebook,点击创建新项目(若进入的是新版的Notebook,请点击回到旧版)
项目名称自拟,工作环境选择第一个(只要包含tensorflow的1.x.x版本即可),类型选择CPU2核8G,资源池选择公共资源池,储存配置选择云硬盘。

创建好Notebook后,新建一个文件(即点击new),选择tensorflow-1.13.1,创建成功后就可以开始写代码啦。

2.代码

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf

#随机点的生成
x_set = []
y_set = []
for i in range(2000):
    x = np.random.normal(0.0,0.5)
    x_set.append(x)
    y = 0.03*x+0.05+np.random.normal(0.0,0.02)
    y_set.append(y)
plt.scatter(x_set,y_set,c='r')
plt.show()

#设置权重参数w,由于是点,所以设置为一维
W = tf.Variable(tf.random.normal([1],mean = 0,stddev = 0.01),name = "W")
#设置偏移量b
b = tf.Variable(tf.zeros([1],tf.float32))
#预测值
y_p = W*x_set+b
#计算损失函数(均方误差)
loss = tf.reduce_mean(tf.square(y_set-y_p))
#梯度下降优化参数
optimizer = tf.train.GradientDescentOptimizer(0.1)#0.1是学习率
train = optimizer.minimize(loss)

#初始化变量
init = tf.compat.v1.global_variables_initializer()
sess = tf.compat.v1.Session()
sess.run(init)

print("W=",sess.run(W),"b=",sess.run(b),"loss=",sess.run(loss))
print("Start training\n")
#训练500for step in range(500):
    sess.run(train)
    print("W=",sess.run(W),"b=",sess.run(b),"loss=",sess.run(loss))
    
print("Training ending\n")
print("W=",sess.run(W),"b=",sess.run(b))

随机生成的点:
image.png

训练后结果:
image.png
训练结果的w和b还是比较接近预设的0.03和0.05的。

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。