MindSpore Reinforcement Learning: Using MindRL

Posted by irrational on 2024/04/14 18:34:15

Installing MindRL

pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.1.0/Reinforcement/x86_64/mindspore_rl-0.7.0-py3-none-linux_x86_64.whl

git clone https://gitee.com/mindspore-lab/mindrl

Verify that the installation works:

cd mindrl/example/dqn
python train.py --episode 1000 --device_target GPU
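
Before running the full example, a plain import test is a quick way to sanity-check the installation. This is only a minimal sketch: it confirms that the packages can be loaded, not that the target device is actually usable.

# Quick sanity check: confirm that MindSpore and MindSpore RL can be imported.
import mindspore
import mindspore_rl

print("MindSpore version:", mindspore.__version__)
print("mindspore_rl installed at:", mindspore_rl.__file__)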


Today we will start by training a DQN agent on CartPole-v0.

Setting the parameters

import argparse
from mindspore_rl.algorithm.dqn import config
from mindspore_rl.algorithm.dqn.dqn_session import DQNSession
from mindspore_rl.algorithm.dqn.dqn_trainer import DQNTrainer
from mindspore import context
from mindspore import dtype as mstype

parser = argparse.ArgumentParser(description='MindSpore Reinforcement DQN')
parser.add_argument('--episode', type=int, default=650, help='total episode numbers.')
parser.add_argument('--device_target', type=str, default='Auto', choices=['Ascend', 'CPU', 'GPU', 'Auto'],
                    help='Choose a device to run the dqn example(Default: Auto).')
parser.add_argument('--precision_mode', type=str, default='fp32', choices=['fp32', 'fp16'],
                    help='Precision mode')
parser.add_argument('--env_yaml', type=str, default='../env_yaml/CartPole-v0.yaml',
                    help='Choose an environment yaml to update the dqn example(Default: CartPole-v0.yaml).')
parser.add_argument('--algo_yaml', type=str, default=None,
                    help='Choose an algo yaml to update the dqn example(Default: None).')
options, _ = parser.parse_known_args()
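
If you run this in a notebook rather than as a script, parse_known_args also accepts an explicit argument list, so the defaults can be overridden without touching sys.argv. A small sketch (the flag values here are just examples):

# Parse an explicit argument list instead of sys.argv (useful in notebooks).
options, _ = parser.parse_known_args(
    ['--episode', '650', '--device_target', 'CPU']
)
print(options.episode, options.device_target, options.env_yaml)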

Configuring the training algorithm

episode = options.episode

# Start training the DQN algorithm.
if options.device_target != 'Auto':
    context.set_context(device_target=options.device_target)
if context.get_context('device_target') in ['CPU']:
    context.set_context(enable_graph_kernel=True)
context.set_context(mode=context.GRAPH_MODE, ascend_config={"precision_mode": "allow_mix_precision"})
compute_type = mstype.float32 if options.precision_mode == 'fp32' else mstype.float16
config.algorithm_config['policy_and_network']['params']['compute_type'] = compute_type
if compute_type == mstype.float16 and options.device_target != 'Ascend':
    raise ValueError("fp16 mode is only supported on the Ascend backend.")
dqn_session = DQNSession(options.env_yaml, options.algo_yaml)
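
Before creating the session you can also inspect or tweak config.algorithm_config directly. The snippet below only prints what is there rather than assuming specific hyperparameter names, since the exact keys depend on the MindRL version; a sketch:

# Inspect the DQN algorithm configuration shipped with MindRL.
# Only 'policy_and_network' -> 'params' is referenced above, so we print the
# available sections instead of assuming other key names.
print("config sections:", list(config.algorithm_config.keys()))
print("policy/network params:", config.algorithm_config['policy_and_network']['params'])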

Starting training

dqn_session.run(class_type=DQNTrainer, episode=episode)


As you can see, MindRL provides a well-packaged training interface, and the episode results it prints can be used to plot training curves.
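
For example, once the training output has been saved to a text file, a small parsing script can turn it into a reward curve. This is only a sketch: it assumes log lines that contain "Episode <n>" together with a "rewards is <value>" field, and a log file named dqn_train.log; adjust the regular expression and file name to whatever your MindRL version actually produces.

import re
import matplotlib.pyplot as plt

# Assumed log format: lines containing "Episode <n> ... rewards is <value>".
# Adjust the pattern if your MindRL version prints a different message.
pattern = re.compile(r"Episode\s+(\d+).*?rewards?\s+is\s+(-?\d+\.?\d*)")

episodes, rewards = [], []
with open("dqn_train.log", "r") as f:   # hypothetical log file
    for line in f:
        match = pattern.search(line)
        if match:
            episodes.append(int(match.group(1)))
            rewards.append(float(match.group(2)))

plt.plot(episodes, rewards)
plt.xlabel("episode")
plt.ylabel("reward")
plt.title("DQN on CartPole-v0")
plt.savefig("dqn_rewards.png")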

Adding a real-time capture context manager to record the log

import sys
import time
from io import StringIO

class RealTimeCaptureAndDisplayOutput(object):
    def __init__(self):
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        self.captured_output = StringIO()

    def write(self, text):
        self._original_stdout.write(text)  # print in real time
        self.captured_output.write(text)   # also save to the buffer

    def flush(self):
        self._original_stdout.flush()
        self.captured_output.flush()

    def __enter__(self):
        sys.stdout = self
        sys.stderr = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        sys.stderr = self._original_stderr

# Use the context manager
with RealTimeCaptureAndDisplayOutput() as captured:
    print("Hello, World!")
    time.sleep(5)
    print("Another message!")

# After the block exits, captured.captured_output.getvalue() returns all captured output
print("\nCaptured Output:")
print(captured.captured_output.getvalue())


# dqn_session.run(class_type=DQNTrainer, episode=episode)
with RealTimeCaptureAndDisplayOutput() as captured:
    dqn_session.run(class_type=DQNTrainer, episode=100)

This makes it much easier to keep a copy of the training output.
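
Since captured_output is just a StringIO buffer, persisting the log after a run is a one-liner; a minimal sketch (the file name is arbitrary, chosen to match the parsing example above):

# Write everything that was captured during training to a log file.
with open("dqn_train.log", "w") as f:
    f.write(captured.captured_output.getvalue())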


We can also build the trainer by hand, as follows:

# Imports required to run the trainer below (module paths as organized in mindspore_rl 0.7).
import mindspore as ms
from mindspore import Parameter, Tensor
from mindspore.ops import operations as P
from mindspore_rl.agent import trainer
from mindspore_rl.agent.trainer import Trainer


class DQNTrainer(Trainer):
    """DQN Trainer"""

    def __init__(self, msrl, params):
        super(DQNTrainer, self).__init__(msrl)
        self.zero = Tensor(0, ms.float32)
        self.squeeze = P.Squeeze()
        self.less = P.Less()
        self.zero_value = Tensor(0, ms.float32)
        self.fill_value = Tensor(1000, ms.float32)
        self.inited = Parameter(Tensor((False,), ms.bool_), name="init_flag")
        self.mod = P.Mod()
        self.false = Tensor((False,), ms.bool_)
        self.true = Tensor((True,), ms.bool_)
        self.num_evaluate_episode = params["num_evaluate_episode"]
        self.update_period = Tensor(5, ms.float32)

    def trainable_variables(self):
        """Trainable variables for saving."""
        trainable_variables = {"policy_net": self.msrl.learner.policy_network}
        return trainable_variables

    @ms.jit
    def init_training(self):
        """Initialize training"""
        state = self.msrl.collect_environment.reset()
        done = self.false
        i = self.zero_value
        while self.less(i, self.fill_value):
            done, _, new_state, action, my_reward = self.msrl.agent_act(
                trainer.INIT, state
            )
            self.msrl.replay_buffer_insert([state, action, my_reward, new_state])
            state = new_state
            if done:
                state = self.msrl.collect_environment.reset()
                done = self.false
            i += 1
        return done

    @ms.jit
    def train_one_episode(self):
        """Train one episode"""
        if not self.inited:
            self.init_training()
            self.inited = self.true
        state = self.msrl.collect_environment.reset()
        done = self.false
        total_reward = self.zero
        steps = self.zero
        loss = self.zero
        while not done:
            done, r, new_state, action, my_reward = self.msrl.agent_act(
                trainer.COLLECT, state
            )
            self.msrl.replay_buffer_insert([state, action, my_reward, new_state])
            state = new_state
            r = self.squeeze(r)
            loss = self.msrl.agent_learn(self.msrl.replay_buffer_sample())
            total_reward += r
            steps += 1
            if not self.mod(steps, self.update_period):
                self.msrl.learner.update()
        return loss, total_reward, steps

    @ms.jit
    def evaluate(self):
        """Policy evaluate"""
        total_reward = self.zero_value
        eval_iter = self.zero_value
        while self.less(eval_iter, self.num_evaluate_episode):
            episode_reward = self.zero_value
            state = self.msrl.eval_environment.reset()
            done = self.false
            while not done:
                done, r, state = self.msrl.agent_act(trainer.EVAL, state)
                r = self.squeeze(r)
                episode_reward += r
            total_reward += episode_reward
            eval_iter += 1
        avg_reward = total_reward / self.num_evaluate_episode
        return avg_reward
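
A hand-written trainer like this is plugged in the same way as the built-in one: pass it as class_type to dqn_session.run, and the params dict received by DQNTrainer.__init__ (including num_evaluate_episode) is filled in from the algorithm configuration by the session. A sketch, assuming the dqn_session and episode created earlier:

# Run the hand-written trainer with the session created earlier.
# The `params` passed to DQNTrainer.__init__ come from the algorithm
# configuration loaded by DQNSession, so no extra wiring is needed here.
dqn_session.run(class_type=DQNTrainer, episode=episode)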