Keras中的MultiStepLR

举报
悲恋花丶无心之人 发表于 2021/02/03 01:14:15 2021/02/03
5.3k+ 0 0
【摘要】 Keras中没有多步调整学习率(MultiStepLR)的调度器,但是博主这里提供一个自己写的: 1.代码 from tensorflow.python.keras.callbacks import Callbackfrom tensorflow.python.keras import backend as Kimport numpy as npimport argpar...

Keras中没有多步调整学习率(MultiStepLR)的调度器,但是博主这里提供一个自己写的:

1.代码


      from tensorflow.python.keras.callbacks import Callback
      from tensorflow.python.keras import backend as K
      import numpy as np
      import argparse
      parser = argparse.ArgumentParser()
      parser.add_argument('--lr_decay_epochs', type=list, default=[2, 5, 7], help="For MultiFactorScheduler step")
      parser.add_argument('--lr_decay_factor', type=float, default=0.1)
      args, _ = parser.parse_known_args()
      def get_lr_scheduler(args):
       lr_scheduler = MultiStepLR(args=args)
      return lr_scheduler
      class MultiStepLR(Callback):
      """Learning rate scheduler.
       Arguments:
       args: parser_setting
       verbose: int. 0: quiet, 1: update messages.
       """
      def __init__(self, args, verbose=0):
       super(MultiStepLR, self).__init__()
       self.args = args
       self.steps = args.lr_decay_epochs
       self.factor = args.lr_decay_factor
       self.verbose = verbose
      def on_epoch_begin(self, epoch, logs=None):
      if not hasattr(self.model.optimizer, 'lr'):
      raise ValueError('Optimizer must have a "lr" attribute.')
       lr = self.schedule(epoch)
      if not isinstance(lr, (float, np.float32, np.float64)):
      raise ValueError('The output of the "schedule" function '
      'should be float.')
       K.set_value(self.model.optimizer.lr, lr)
       print("learning rate: {:.7f}".format(K.get_value(self.model.optimizer.lr)).rstrip('0'))
      if self.verbose > 0:
       print('\nEpoch %05d: MultiStepLR reducing learning '
      'rate to %s.' % (epoch + 1, lr))
      def schedule(self, epoch):
       lr = K.get_value(self.model.optimizer.lr)
      for i in range(len(self.steps)):
      if epoch == self.steps[i]:
       lr = lr * self.factor
      return lr
  
 

2.调用(callbacks里append这个lr_scheduler,fit_generator里callbacks传入这个变量)


      callbacks = []
      lr_scheduler = get_lr_scheduler(args=args)
      callbacks.append(lr_scheduler)
      ...
      model.fit_generator(train_generator,
       steps_per_epoch=train_generator.samples // args.batch_size,
       validation_data=test_generator,
       validation_steps=test_generator.samples // args.batch_size,
       workers=args.num_workers,
       callbacks=callbacks,  # 你的callbacks, 包含了lr_scheduler
       epochs=args.epochs,
       )
  
 

大家可以拿去用~

文章来源: nickhuang1996.blog.csdn.net,作者:悲恋花丶无心之人,版权归原作者所有,如需转载,请联系作者。

原文链接:nickhuang1996.blog.csdn.net/article/details/103645204

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

作者其他文章

评论(0

抱歉,系统识别当前为高风险访问,暂不支持该操作

    全部回复

    上滑加载中

    设置昵称

    在此一键设置昵称,即可参与社区互动!

    *长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

    *长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。