MindSpore Reinforcement Learning: Using MindRL
Install MindRL
pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.1.0/Reinforcement/x86_64/mindspore_rl-0.7.0-py3-none-linux_x86_64.whl
git clone https://gitee.com/mindspore-lab/mindrl
Then verify that the installation works by running the bundled DQN example:
cd mindrl/example/dqn
python train.py --episode 1000 --device_target GPU
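The CartPole example runs on top of OpenAI Gym, so if training fails to start it is worth checking that the Gym environment itself can be created and stepped. Below is a minimal standalone sketch; it assumes the gym package is installed, and it tolerates the slightly different return tuples of newer gym releases.
import gym

# Sanity check of the underlying Gym environment (independent of MindRL).
env = gym.make("CartPole-v0")
reset_out = env.reset()
obs = reset_out[0] if isinstance(reset_out, tuple) else reset_out  # gym>=0.26 returns (obs, info)

total_reward = 0.0
for _ in range(200):
    step_out = env.step(env.action_space.sample())        # take a random action
    reward, done = step_out[1], step_out[2]
    total_reward += reward
    if done or (len(step_out) == 5 and step_out[3]):       # newer gym: terminated / truncated
        break

print("random-policy reward:", total_reward)
env.close()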
Today, let's start by training DQN on CartPole-v0.
Set up the parameters
import argparse
from mindspore_rl.algorithm.dqn import config
from mindspore_rl.algorithm.dqn.dqn_session import DQNSession
from mindspore_rl.algorithm.dqn.dqn_trainer import DQNTrainer
from mindspore import context
from mindspore import dtype as mstype
parser = argparse.ArgumentParser(description='MindSpore Reinforcement DQN')
parser.add_argument('--episode', type=int, default=650, help='total episode numbers.')
parser.add_argument('--device_target', type=str, default='Auto', choices=['Ascend', 'CPU', 'GPU', 'Auto'],
                    help='Choose a device to run the dqn example(Default: Auto).')
parser.add_argument('--precision_mode', type=str, default='fp32', choices=['fp32', 'fp16'],
                    help='Precision mode')
parser.add_argument('--env_yaml', type=str, default='../env_yaml/CartPole-v0.yaml',
                    help='Choose an environment yaml to update the dqn example(Default: CartPole-v0.yaml).')
parser.add_argument('--algo_yaml', type=str, default=None,
                    help='Choose an algo yaml to update the dqn example(Default: None).')
options, _ = parser.parse_known_args()
Configure the training algorithm
episode = options.episode

# Start to configure the DQN algorithm
if options.device_target != 'Auto':
    context.set_context(device_target=options.device_target)
if context.get_context('device_target') in ['CPU']:
    context.set_context(enable_graph_kernel=True)
context.set_context(mode=context.GRAPH_MODE)
# ascend_config only takes effect on the Ascend backend
if context.get_context('device_target') == 'Ascend':
    context.set_context(ascend_config={"precision_mode": "allow_mix_precision"})
compute_type = mstype.float32 if options.precision_mode == 'fp32' else mstype.float16
config.algorithm_config['policy_and_network']['params']['compute_type'] = compute_type
if compute_type == mstype.float16 and options.device_target != 'Ascend':
    raise ValueError("Fp16 mode is only supported by the Ascend backend.")
dqn_session = DQNSession(options.env_yaml, options.algo_yaml)
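The assignment to compute_type above works because the DQN hyperparameters live in config.algorithm_config. Printing that dictionary is a safe way to see what else could be overridden; only the 'policy_and_network' → 'params' path is referenced above, and the remaining keys vary with the installed MindRL version.
import pprint

# Dump the algorithm configuration. Only 'policy_and_network' -> 'params' is
# used in this post; other sections and key names depend on the MindRL version.
pprint.pprint(config.algorithm_config)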
Start training
dqn_session.run(class_type=DQNTrainer, episode=episode)
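The bundled example script goes one step further and passes callbacks into the run call for loss printing, periodic evaluation, and checkpointing. The sketch below mirrors that variant; it assumes the callback classes in mindspore_rl.utils.callback, the trainer_params dictionary in the DQN config module, and the params/callbacks keywords of run, all taken from the official DQN example rather than anything shown above.
from mindspore_rl.utils.callback import CheckpointCallback, EvaluateCallback, LossCallback

# Print the loss each episode, evaluate every 10 episodes,
# and write a checkpoint to ./ckpt every 50 episodes (as in the official example).
callbacks = [LossCallback(), EvaluateCallback(10), CheckpointCallback(50, './ckpt')]
dqn_session.run(class_type=DQNTrainer, episode=episode,
                params=config.trainer_params, callbacks=callbacks)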
As you can see, MindRL provides a very well-encapsulated interface, and the printed training output can be used to plot learning curves (see the plotting sketch after the capture example below).
Add a real-time output capture context manager to record the log
import sys
import time
from io import StringIO

class RealTimeCaptureAndDisplayOutput(object):
    """Context manager that echoes stdout/stderr in real time while keeping a copy."""

    def __init__(self):
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        self.captured_output = StringIO()

    def write(self, text):
        self._original_stdout.write(text)  # print in real time
        self.captured_output.write(text)   # keep a copy in the buffer

    def flush(self):
        self._original_stdout.flush()
        self.captured_output.flush()

    def __enter__(self):
        sys.stdout = self
        sys.stderr = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        sys.stderr = self._original_stderr
# Use the context manager
with RealTimeCaptureAndDisplayOutput() as captured:
    print("Hello, World!")
    time.sleep(5)
    print("Another message!")

# Afterwards, captured.captured_output.getvalue() returns everything that was captured
print("\nCaptured Output:")
print(captured.captured_output.getvalue())
# dqn_session.run(class_type=DQNTrainer, episode=episode)
with RealTimeCaptureAndDisplayOutput() as captured:
    dqn_session.run(class_type=DQNTrainer, episode=100)
This way the training output is still displayed in real time while also being stored for later processing.
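With the captured text in hand, the per-episode rewards can be turned into a curve. Below is a minimal sketch; the regular expression encodes an assumption about the wording of MindRL's log lines ("rewards is <number>"), so adjust it to whatever your version actually prints, and matplotlib is assumed to be installed.
import re
import matplotlib.pyplot as plt

log_text = captured.captured_output.getvalue()

# Assumed log format: lines containing "rewards is <number>".
rewards = [float(m) for m in re.findall(r"rewards\s+is\s+(-?\d+(?:\.\d+)?)", log_text)]

if rewards:
    plt.plot(range(len(rewards)), rewards)
    plt.xlabel("episode")
    plt.ylabel("reward")
    plt.title("DQN on CartPole-v0")
    plt.savefig("dqn_cartpole_reward.png")
else:
    print("No reward entries matched; inspect the log and adjust the regex.")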
We can also build the trainer by hand, as follows:
# Imports needed for the hand-written trainer
import mindspore as ms
from mindspore import Parameter, Tensor
from mindspore.ops import operations as P
from mindspore_rl.agent import trainer
from mindspore_rl.agent.trainer import Trainer


class DQNTrainer(Trainer):
    """DQN Trainer"""

    def __init__(self, msrl, params):
        super(DQNTrainer, self).__init__(msrl)
        self.zero = Tensor(0, ms.float32)
        self.squeeze = P.Squeeze()
        self.less = P.Less()
        self.zero_value = Tensor(0, ms.float32)
        self.fill_value = Tensor(1000, ms.float32)
        self.inited = Parameter(Tensor((False,), ms.bool_), name="init_flag")
        self.mod = P.Mod()
        self.false = Tensor((False,), ms.bool_)
        self.true = Tensor((True,), ms.bool_)
        self.num_evaluate_episode = params["num_evaluate_episode"]
        self.update_period = Tensor(5, ms.float32)

    def trainable_variables(self):
        """Trainable variables for saving."""
        trainable_variables = {"policy_net": self.msrl.learner.policy_network}
        return trainable_variables

    @ms.jit
    def init_training(self):
        """Initialize training: fill the replay buffer with initial transitions."""
        state = self.msrl.collect_environment.reset()
        done = self.false
        i = self.zero_value
        while self.less(i, self.fill_value):
            done, _, new_state, action, my_reward = self.msrl.agent_act(
                trainer.INIT, state)
            self.msrl.replay_buffer_insert([state, action, my_reward, new_state])
            state = new_state
            if done:
                state = self.msrl.collect_environment.reset()
                done = self.false
            i += 1
        return done

    @ms.jit
    def train_one_episode(self):
        """Train one episode"""
        if not self.inited:
            self.init_training()
            self.inited = self.true
        state = self.msrl.collect_environment.reset()
        done = self.false
        total_reward = self.zero
        steps = self.zero
        loss = self.zero
        while not done:
            done, r, new_state, action, my_reward = self.msrl.agent_act(
                trainer.COLLECT, state)
            self.msrl.replay_buffer_insert([state, action, my_reward, new_state])
            state = new_state
            r = self.squeeze(r)
            loss = self.msrl.agent_learn(self.msrl.replay_buffer_sample())
            total_reward += r
            steps += 1
            # Sync the target network every update_period steps
            if not self.mod(steps, self.update_period):
                self.msrl.learner.update()
        return loss, total_reward, steps

    @ms.jit
    def evaluate(self):
        """Policy evaluate"""
        total_reward = self.zero_value
        eval_iter = self.zero_value
        while self.less(eval_iter, self.num_evaluate_episode):
            episode_reward = self.zero_value
            state = self.msrl.eval_environment.reset()
            done = self.false
            while not done:
                done, r, state = self.msrl.agent_act(trainer.EVAL, state)
                r = self.squeeze(r)
                episode_reward += r
            total_reward += episode_reward
            eval_iter += 1
        avg_reward = total_reward / self.num_evaluate_episode
        return avg_reward
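Because this hand-written class shadows the DQNTrainer imported earlier, it can be handed to the session exactly like the built-in one. A minimal sketch follows; it assumes config.trainer_params provides the num_evaluate_episode entry the constructor reads, and that run accepts a params keyword as in the official example.
# Reuse the session created earlier; the class defined above is picked up here.
dqn_session.run(class_type=DQNTrainer, episode=200, params=config.trainer_params)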