走近深度学习,认识MoXing:优化器配置
用户可以使用mox.get_optimizer_fn来获取MoXing内置的Optimizer,也可以使用TensorFlow定义或由用户自己实现的Optimizer。此外,MoXing还提供了OptimizerWrapper的用法。
1 基础Optimizer
使用内置OPT:
mox.run(...,
optimizer_fn=mox.get_optimizer_fn('momentum', learning_rate=0.01, momentum=0.9),
...)
使用TF定义的OPT:
mox.run(...,
optimizer_fn=lambda: tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9),
...)
使用自定义的OPT:
def optimizer_fn():
...
return my_optimizer()
mox.run(...,
optimizer_fn=optimizer_fn,
...)
mox.run中optimizer_fn需要传入的是一个返回optimizer的函数,而不是一个optimizer,以下代码的使用方式是错误的:
mox.run(...,
optimizer_fn=tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9),
...)
此时可能会出现如下错误信息:
TypeError: 'MomentumOptimizer' object is not callable
只需要在optimizer上加上lambda表达式就能正确
mox.run(...,
optimizer_fn=lambda: tf.train.MomentumOptimizer(learning_rate=0.01, momentum=0.9),
...)
2 封装器OptimizerWrapper
使用mox.get_optimizer_wrapper_fn可以获取Optimizer的高级应用方法。OptimizerWrapper是对optimizer的一层封装,类似tf.train.SyncReplicasOptimizer的用法。并且在允许的范围内,可以使用多层封装。样例代码如下。
使用Batch Gradient Descent,基础OPT为Momentum,每经过8个step的周期提交一次累计梯度。
def optimizer_fn():
opt = mox.get_optimizer_fn('momentum', learning_rate=lr, momentum=0.9)()
opt = mox.get_optimizer_wrapper_fn('batch_gradients', opt, num_batches=8, sync_on_apply=True)()
return opt
mox.run(..., optimizer_fn=optimizer_fn, ...)
当遇到输出信息如下:
WARNING:tensorflow:Using OptimizerWrapper when sync_replicas is True may cause performance loss.
这并不是一个错误,大多数OptimizerWrapper都要求在异步模式下使用,如Batch Gradient Descent当没有到通信周期时,分布式的每个worker是异步的,而到了通信周期时,是通过Optimizer本身的sync_on_apply=True参数来做同步,所以需要设置运行参数--sync_replicas=False来启动一个异步分布式运行,才能发挥Batch Gradient Descent的性能优势。另外类似EASGD这类Optimizer本身就要求在异步模型下运行。
复现bact_size=32k训练ResNet-50,当节点数量不够时,可以通过Batch Gradient Descent等效增加每个节点的batch_size,并且使用LARS训练,此时将涉及3层Optimizer的封装:
def optimizer_fn():
lr = config_lr(...)
opt = mox.get_optimizer_fn('momentum', learning_rate=lr, momentum=0.9)()
opt = mox.get_optimizer_wrapper_fn('lars', opt, ratio=0.001, weight_decay=0.0001)()
opt = mox.get_optimizer_wrapper_fn('batch_gradients', opt, num_batches=8, sync_on_apply=True)()
注意:
· 当run_mode为mox.ModeKeys.TRAIN时,optimizer_fn必须填充。
· 当run_mode为mox.ModeKeys.EVAL时,optimizer_fn不需要填充。
MoXing系列文章下期预告:运行与公共组件。
- 点赞
- 收藏
- 关注作者
评论(0)