MindSpore AC算法强化学习

举报
irrational 发表于 2024/06/04 12:17:36 2024/06/04
【摘要】 AC算法,也称为Actor-Critic算法,是强化学习中的一种重要方法。它结合了策略梯度方法和价值函数方法的优点,主要由两部分组成:演员(Actor)和评论家(Critic)。演员(Actor):负责根据当前状态选择动作。通常采用策略函数 π(a|s) 来表示在给定状态 s 下采取动作 a 的概率。目标是学习一种策略,以最大化长期的累积奖励。评论家(Critic):评估演员采取的动作有多好...

AC算法,也称为Actor-Critic算法,是强化学习中的一种重要方法。它结合了策略梯度方法和价值函数方法的优点,主要由两部分组成:演员(Actor)和评论家(Critic)。

  1. 演员(Actor)
    • 负责根据当前状态选择动作。
    • 通常采用策略函数 π(a|s) 来表示在给定状态 s 下采取动作 a 的概率。
    • 目标是学习一种策略,以最大化长期的累积奖励。
  2. 评论家(Critic)
    • 评估演员采取的动作有多好。
    • 使用价值函数 V(s) 或 Q(s, a) 来衡量在状态 s 或在状态 s 下采取动作 a 的预期回报。
    • 目标是准确预测未来的回报,以指导演员的决策。
  3. 训练过程
    • 演员根据当前策略选择动作,环境根据这一动作返回新的状态和奖励。
    • 评论家根据奖励和新状态来评估这一动作的价值,并提供反馈给演员。
    • 演员根据评论家的反馈通过策略梯度方法调整其策略,以提高未来动作的预期回报。
  4. 算法特点
    • 平衡探索与利用:AC 算法通过持续更新策略来平衡探索(探索新动作)和利用(重复已知的好动作)。
    • 减少方差:由于评论家的引导,演员的策略更新更加稳定,减少了策略梯度方法中的方差。
    • 适用性:AC 算法适用于离散和连续动作空间,可以处理复杂的决策问题。
      伪代码方面,Actor-Critic算法的一个典型流程包括以下步骤:
  5. 使用来自参与者网络的策略 πθ 对 {s_t, a_t} 进行采样。
  6. 评估优势函数 A_t,也称为TD误差 δt。在Actor-Critic算法中,优势函数是由评论者网络产生的。
  7. 使用特定表达式评估梯度。
  8. 更新策略参数 θ。
  9. 更新基于评价者的基于价值的RL(Q学习)的权重。δt等于优势函数。
  10. 重复以上步骤,直到找到最佳策略 πθ。
    这个算法框架是一个很好的起点,但要应用于实际还需要进一步的发展。主要挑战在于如何有效管理两个神经网络(演员和评论家)的梯度更新,并确保它们相互依赖和协调。

导入相关包

import argparse
from mindspore_rl.algorithm.ac.ac_trainer import ACTrainer
from mindspore_rl.algorithm.ac.ac_session import ACSession
from mindspore import context

parser = argparse.ArgumentParser(description='MindSpore Reinforcement AC')
parser.add_argument('--episode', type=int, default=1000, help='total episode numbers.')
parser.add_argument('--device_target', type=str, default='Auto', choices=['Ascend', 'CPU', 'GPU', 'Auto'],
                    help='Choose a device to run the ac example(Default: Auto).')
parser.add_argument('--env_yaml', type=str, default='../env_yaml/CartPole-v0.yaml',
                    help='Choose an environment yaml to update the ac example(Default: CartPole-v0.yaml).')
parser.add_argument('--algo_yaml', type=str, default=None,
                    help='Choose an algo yaml to update the ac example(Default: None).')
options, _ = parser.parse_known_args()

启动环境

episode=options.episode
"""start to train ac algorithm"""
if options.device_target != 'Auto':
    context.set_context(device_target=options.device_target)
if context.get_context('device_target') in ['CPU']:
    context.set_context(enable_graph_kernel=True)
context.set_context(mode=context.GRAPH_MODE)
ac_session = ACSession(options.env_yaml, options.algo_yaml)

上下文管理

import sys
import time
from io import StringIO

class RealTimeCaptureAndDisplayOutput(object):
    def __init__(self):
        self._original_stdout = sys.stdout
        self._original_stderr = sys.stderr
        self.captured_output = StringIO()

    def write(self, text):
        self._original_stdout.write(text)  # 实时打印
        self.captured_output.write(text)   # 保存到缓冲区

    def flush(self):
        self._original_stdout.flush()
        self.captured_output.flush()

    def __enter__(self):
        sys.stdout = self
        sys.stderr = self
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        sys.stdout = self._original_stdout
        sys.stderr = self._original_stderr
episode=100
# dqn_session.run(class_type=DQNTrainer, episode=episode)
with RealTimeCaptureAndDisplayOutput() as captured_new:
    ac_session.run(class_type=ACTrainer, episode=episode)
import re
import matplotlib.pyplot as plt

# 原始输出
raw_output = captured_new.captured_output.getvalue()

# 使用正则表达式从输出中提取loss和rewards
loss_pattern = r"loss is (\d+\.\d+)"
reward_pattern = r"rewards is (\d+\.\d+)"
loss_values = [float(match.group(1)) for match in re.finditer(loss_pattern, raw_output)]
reward_values = [float(match.group(1)) for match in re.finditer(reward_pattern, raw_output)]

# 绘制loss曲线
plt.plot(loss_values, label='Loss')
plt.xlabel('Episode')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.legend()
plt.show()

# 绘制reward曲线
plt.plot(reward_values, label='Rewards')
plt.xlabel('Episode')
plt.ylabel('Rewards')
plt.title('Rewards Curve')
plt.legend()
plt.show()

image.png
image.png

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。