pytorch不同的层设置不同的学习率

举报
AI浩 发表于 2021/12/23 01:29:18 2021/12/23
1.5k+ 0 0
【摘要】   import torchfrom torch import nn, optimfrom torch.autograd import Variableimport numpy as npimport matplotlib.pyplot as plt x_train = np.array([[3.3], [4.4], [5....
 

      import torch
      from torch import nn, optim
      from torch.autograd import Variable
      import numpy as np
      import matplotlib.pyplot as plt
      x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
                          [9.779], [6.182], [7.59], [2.167], [7.042],
                          [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
      y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
                          [3.366], [2.596], [2.53], [1.221], [2.827],
                          [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)
      x_train = torch.from_numpy(x_train)
      y_train = torch.from_numpy(y_train)
      # Linear Regression Model
      class LinearRegression(nn.Module):
         def __init__(self):
             super(LinearRegression, self).__init__()
              self.linear1 = nn.Linear(1, 5)  # input and output is 1 dimension
              self.linear2 = nn.Linear(5, 1)
         def forward(self, x):
              out = self.linear1(x)
              out = self.linear2(out)
             return out
      model = LinearRegression()
      print(model.linear1)
      # 微调:自定义每一层的学习率
      # 定义loss和优化函数
      criterion = nn.MSELoss()
      optimizer = optim.SGD(
          [{"params": model.linear1.parameters(), "lr": 0.01},
           {"params": model.linear2.parameters()}],
          lr=0.02)
  
 

文章来源: wanghao.blog.csdn.net,作者:AI浩,版权归原作者所有,如需转载,请联系作者。

原文链接:wanghao.blog.csdn.net/article/details/114600211

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

作者其他文章

评论(0

抱歉,系统识别当前为高风险访问,暂不支持该操作

    全部回复

    上滑加载中

    设置昵称

    在此一键设置昵称,即可参与社区互动!

    *长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

    *长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。