走近深度学习,认识MoXing:优化器配置

举报
云上AI 发表于 2018/08/22 10:24:27 2018/08/22
【摘要】 本文为MoXing系列文章第五篇,主要介绍Optimizer、OptimizerWrapper。

用户可以使用mox.get_optimizer_fn来获取MoXing内置的Optimizer,也可以使用TensorFlow定义或由用户自己实现的Optimizer。此外,MoXing还提供了OptimizerWrapper的用法。


1  基础Optimizer

使用内置OPT:

mox.run(...,

        optimizer_fn=mox.get_optimizer_fn('momentum', learning_rate=0.01, momentum=0.9),

        ...)

使用TF定义的OPT:

mox.run(...,

        optimizer_fn=lambda: tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9),

        ...)

 

使用自定义的OPT:

def optimizer_fn():

  ...

return my_optimizer()

mox.run(...,

        optimizer_fn=optimizer_fn,

        ...)


mox.run中optimizer_fn需要传入的是一个返回optimizer的函数,而不是一个optimizer,以下代码的使用方式是错误的:

mox.run(...,

          optimizer_fn=tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9),

          ...)

此时可能会出现如下错误信息:

TypeError: 'MomentumOptimizer' object is not callable

只需要在optimizer上加上lambda表达式就能正确

mox.run(...,

          optimizer_fn=lambda: tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9),

          ...)


2  封装器OptimizerWrapper

使用mox.get_optimizer_wrapper_fn可以获取Optimizer的高级应用方法。OptimizerWrapper是对optimizer的一层封装,类似tf.train.SyncReplicasOptimizer的用法。并且在允许的范围内,可以使用多层封装。样例代码如下。

使用Batch Gradient Descent,基础OPT为Momentum,每经过8个step的周期提交一次累计梯度。

def optimizer_fn():

opt = mox.get_optimizer_fn('momentum', learning_rate=lr, momentum=0.9)()

opt = mox.get_optimizer_wrapper_fn('batch_gradients', opt, num_batches=8, sync_on_apply=True)()

return opt

mox.run(..., optimizer_fn=optimizer_fn, ...)


当遇到输出信息如下:

WARNING:tensorflow:Using OptimizerWrapper when sync_replicas is True may cause performance loss.

这并不是一个错误,大多数OptimizerWrapper都要求在异步模式下使用,如Batch Gradient Descent当没有到通信周期时,分布式的每个worker是异步的,而到了通信周期时,是通过Optimizer本身的sync_on_apply=True参数来做同步,所以需要设置运行参数--sync_replicas=False来启动一个异步分布式运行,才能发挥Batch Gradient Descent的性能优势。另外类似EASGD这类Optimizer本身就要求在异步模型下运行。

复现bact_size=32k训练ResNet-50,当节点数量不够时,可以通过Batch Gradient Descent等效增加每个节点的batch_size,并且使用LARS训练,此时将涉及3层Optimizer的封装:

def optimizer_fn():

lr = config_lr(...)

opt = mox.get_optimizer_fn('momentum', learning_rate=lr, momentum=0.9)()

opt = mox.get_optimizer_wrapper_fn('lars', opt, ratio=0.001, weight_decay=0.0001)()

opt = mox.get_optimizer_wrapper_fn('batch_gradients', opt, num_batches=8, sync_on_apply=True)()

注意:

·         当run_mode为mox.ModeKeys.TRAIN时,optimizer_fn必须填充。

·         当run_mode为mox.ModeKeys.EVAL时,optimizer_fn不需要填充。




MoXing系列文章下期预告:运行与公共组件。


【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。