MultiStepLR in Keras
【Abstract】 Keras has no built-in multi-step learning-rate scheduler (MultiStepLR), so here is one I wrote myself:
1. Code
from tensorflow.python.keras.callbacks import Callback
from tensorflow.python.keras import backend as K
import numpy as np
import argparse

parser = argparse.ArgumentParser()
# nargs='+' with type=int parses "--lr_decay_epochs 2 5 7" into [2, 5, 7];
# a bare type=list would split the raw argument string into single characters.
parser.add_argument('--lr_decay_epochs', nargs='+', type=int, default=[2, 5, 7],
                    help="For MultiFactorScheduler step")
parser.add_argument('--lr_decay_factor', type=float, default=0.1)
args, _ = parser.parse_known_args()


def get_lr_scheduler(args):
    lr_scheduler = MultiStepLR(args=args)
    return lr_scheduler


class MultiStepLR(Callback):
    """Learning rate scheduler.

    Arguments:
        args: parser settings (lr_decay_epochs, lr_decay_factor)
        verbose: int. 0: quiet, 1: update messages.
    """

    def __init__(self, args, verbose=0):
        super(MultiStepLR, self).__init__()
        self.args = args
        self.steps = args.lr_decay_epochs
        self.factor = args.lr_decay_factor
        self.verbose = verbose

    def on_epoch_begin(self, epoch, logs=None):
        if not hasattr(self.model.optimizer, 'lr'):
            raise ValueError('Optimizer must have a "lr" attribute.')
        lr = self.schedule(epoch)
        if not isinstance(lr, (float, np.float32, np.float64)):
            raise ValueError('The output of the "schedule" function '
                             'should be float.')
        K.set_value(self.model.optimizer.lr, lr)
        print("learning rate: {:.7f}".format(K.get_value(self.model.optimizer.lr)).rstrip('0'))
        if self.verbose > 0:
            print('\nEpoch %05d: MultiStepLR reducing learning '
                  'rate to %s.' % (epoch + 1, lr))

    def schedule(self, epoch):
        # Multiply the current lr by the decay factor whenever the new
        # epoch index is one of the milestone epochs.
        lr = K.get_value(self.model.optimizer.lr)
        if epoch in self.steps:
            lr = lr * self.factor
        return lr
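To see what this schedule produces over a run, here is a small framework-free sketch that replays the same multiplicative decay (the function name `multistep_lr` and the base learning rate of 0.01 are illustrative, not from the original code):

```python
def multistep_lr(base_lr, milestones, factor, num_epochs):
    """Return the learning rate used at each epoch: the lr is
    multiplied by `factor` at the start of every milestone epoch."""
    lrs = []
    lr = base_lr
    for epoch in range(num_epochs):
        if epoch in milestones:
            lr *= factor
        lrs.append(lr)
    return lrs

# With the post's defaults (milestones [2, 5, 7], factor 0.1) and a base
# lr of 0.01: epochs 0-1 keep 0.01, the lr then drops tenfold at epochs
# 2, 5 and 7.
print(multistep_lr(0.01, [2, 5, 7], 0.1, 9))
```

Because the decay is applied to the *current* lr, milestones compound: three milestones with factor 0.1 leave the final lr at one thousandth of the base value.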
2. Usage (append this lr_scheduler to your callbacks list, then pass the list to fit_generator via the callbacks argument)
callbacks = []
lr_scheduler = get_lr_scheduler(args=args)
callbacks.append(lr_scheduler)
...
model.fit_generator(train_generator,
                    steps_per_epoch=train_generator.samples // args.batch_size,
                    validation_data=test_generator,
                    validation_steps=test_generator.samples // args.batch_size,
                    workers=args.num_workers,
                    callbacks=callbacks,  # your callbacks, including lr_scheduler
                    epochs=args.epochs,
                    )
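One pitfall when passing the milestones on the command line: argparse's `type=list` applies `list()` to the raw string and splits it into single characters, whereas the `nargs='+'` / `type=int` form yields a proper list of ints. A minimal, self-contained sketch:

```python
import argparse

# nargs='+' collects one or more space-separated values and type=int
# converts each, so "--lr_decay_epochs 3 6 9" becomes [3, 6, 9].
parser = argparse.ArgumentParser()
parser.add_argument('--lr_decay_epochs', nargs='+', type=int, default=[2, 5, 7])
parser.add_argument('--lr_decay_factor', type=float, default=0.1)

args, _ = parser.parse_known_args(['--lr_decay_epochs', '3', '6', '9'])
print(args.lr_decay_epochs)  # → [3, 6, 9]
print(args.lr_decay_factor)  # → 0.1 (default, flag not given)
```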
Feel free to use it~
Article source: nickhuang1996.blog.csdn.net, author: 悲恋花丶无心之人. Copyright belongs to the original author; please contact the author for reprint permission.
Original link: nickhuang1996.blog.csdn.net/article/details/103645204
【Copyright Notice】This article was reposted by a Huawei Cloud community user. If you find content in this community suspected of plagiarism, please report it by email with supporting evidence; once verified, the community will immediately remove the infringing content. Report email:
cloudbbs@huaweicloud.com