mxnet优化器 SGD_GC

举报
风吹稻花香 发表于 2021/06/05 00:01:26 2021/06/05
【摘要】 11年it研发经验,从一个会计转行为算法工程师,学过C#,c++,java,android,php,go,js,python,CNN神经网络,四千多篇博文,三千多篇原创,只为与你分享,共同成长,一起进步,关注我,给你分享更多干货知识! mxnet优化器 sgd_gc代码: 原文:https://github.com/mnikitin/Gradient-Centraliza...

11年it研发经验,从一个会计转行为算法工程师,学过C#,c++,java,android,php,go,js,python,CNN神经网络,四千多篇博文,三千多篇原创,只为与你分享,共同成长,一起进步,关注我,给你分享更多干货知识!

mxnet优化器 sgd_gc代码:

原文:https://github.com/mnikitin/Gradient-Centralization

调用代码:


  
  1. import optimizer
  2. opt_params = {'learning_rate': 0.001}
  3. sgd_gc = optimizer.SGDGC(gc_type='gc', **opt_params)
  4. sgd_gcc = optimizer.SGDGC(gc_type='gcc', **opt_params)
  5. adam_gc = optimizer.AdamGC(gc_type='gc', **opt_params)
  6. adam_gcc = optimizer.AdamGC(gc_type='gcc', **opt_params)

 


  
  1. python3 mnist.py --optimizer sgdgc --gc-type gc --lr 0.1 --seed 42
  2. python3 mnist.py --optimizer adamgc --gc-type gcc --lr 0.001 --seed 42

 


  
  1. import mxnet as mx
  2. __all__ = []
  3. def _register_gc_opt():
  4. optimizers = dict()
  5. for name in dir(mx.optimizer):
  6. obj = getattr(mx.optimizer, name)
  7. if hasattr(obj, '__base__') and obj.__base__ == mx.optimizer.Optimizer:
  8. optimizers[name] = obj
  9. suffix = 'GC'
  10. def __init__(self, gc_type='gc', **kwargs):
  11. assert gc_type.lower() in ['gc', 'gcc']
  12. self.gc_ndim_thr = 1 if gc_type.lower() == 'gc' else 3
  13. super(self.__class__, self).__init__(**kwargs)
  14. def update(self, index, weight, grad, state):
  15. self._gc_update_impl(
  16. index, weight, grad, state,
  17. super(self.__class__, self).update)
  18. def update_multi_precision(self, index, weight, grad, state):
  19. self._gc_update_impl(
  20. index, weight, grad, state,
  21. super(self.__class__, self).update_multi_precision)
  22. def _gc_update_impl(self, indexes, weights, grads, states, update_func):
  23. # centralize gradients
  24. if isinstance(indexes, (list, tuple)):
  25. # multi index case: SGD optimizer
  26. for grad in grads:
  27. if len(grad.shape) > self.gc_ndim_thr:
  28. grad -= grad.mean(axis=tuple(range(1, len(grad.shape))), keepdims=True)
  29. else:
  30. # single index case: all other optimizers
  31. if len(grads.shape) > self.gc_ndim_thr:
  32. grads -= grads.mean(axis=tuple(range(1, len(grads.shape))), keepdims=True)
  33. # update weights using centralized gradients
  34. update_func(indexes, weights, grads, states)
  35. inst_dict = dict(
  36. __init__=__init__,
  37. update=update,
  38. update_multi_precision=update_multi_precision,
  39. _gc_update_impl=_gc_update_impl,
  40. )
  41. for k, v in optimizers.items():
  42. name = k + suffix
  43. inst = type(name, (v, ), inst_dict)
  44. mx.optimizer.Optimizer.register(inst)
  45. globals()[name] = inst
  46. __all__.append(name)
  47. _register_gc_opt()
  48. if __name__ == '__main__':
  49. import optimizer
  50. # opt_params = {'learning_rate': 0.001}
  51. # sgd_gc = optimizer.SGDGC(gc_type='gc', **opt_params)
  52. # sgd_gcc = optimizer.SGDGC(gc_type='gcc', **opt_params)
  53. # adam_gc = optimizer.AdamGC(gc_type='gc', **opt_params)
  54. # adam_gcc = optimizer.AdamGC(gc_type='gcc', **opt_params)

 

文章来源: blog.csdn.net,作者:网奇,版权归原作者所有,如需转载,请联系作者。

原文链接:blog.csdn.net/jacke121/article/details/116066322

【版权声明】本文为华为云社区用户转载文章,如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。