MindSpore AC算法强化学习
【摘要】 AC算法,也称为Actor-Critic算法,是强化学习中的一种重要方法。它结合了策略梯度方法和价值函数方法的优点,主要由两部分组成:演员(Actor)和评论家(Critic)。演员(Actor):负责根据当前状态选择动作。通常采用策略函数 π(a|s) 来表示在给定状态 s 下采取动作 a 的概率。目标是学习一种策略,以最大化长期的累积奖励。评论家(Critic):评估演员采取的动作有多好...
AC算法,也称为Actor-Critic算法,是强化学习中的一种重要方法。它结合了策略梯度方法和价值函数方法的优点,主要由两部分组成:演员(Actor)和评论家(Critic)。
- 演员(Actor):
- 负责根据当前状态选择动作。
- 通常采用策略函数 π(a|s) 来表示在给定状态 s 下采取动作 a 的概率。
- 目标是学习一种策略,以最大化长期的累积奖励。
- 评论家(Critic):
- 评估演员采取的动作有多好。
- 使用价值函数 V(s) 或 Q(s, a) 来衡量在状态 s 或在状态 s 下采取动作 a 的预期回报。
- 目标是准确预测未来的回报,以指导演员的决策。
- 训练过程:
- 演员根据当前策略选择动作,环境根据这一动作返回新的状态和奖励。
- 评论家根据奖励和新状态来评估这一动作的价值,并提供反馈给演员。
- 演员根据评论家的反馈通过策略梯度方法调整其策略,以提高未来动作的预期回报。
- 算法特点:
- 平衡探索与利用:AC 算法通过持续更新策略来平衡探索(探索新动作)和利用(重复已知的好动作)。
- 减少方差:由于评论家的引导,演员的策略更新更加稳定,减少了策略梯度方法中的方差。
- 适用性:AC 算法适用于离散和连续动作空间,可以处理复杂的决策问题。
伪代码方面,Actor-Critic算法的一个典型流程包括以下步骤:
- 使用来自参与者网络的策略 πθ 对 {s_t, a_t} 进行采样。
- 评估优势函数 A_t,也称为TD误差 δt。在Actor-Critic算法中,优势函数是由评论者网络产生的。
- 使用特定表达式评估梯度。
- 更新策略参数 θ。
- 更新基于评价者的基于价值的RL(Q学习)的权重。δt等于优势函数。
- 重复以上步骤,直到找到最佳策略 πθ。
这个算法框架是一个很好的起点,但要应用于实际还需要进一步的发展。主要挑战在于如何有效管理两个神经网络(演员和评论家)的梯度更新,并确保它们相互依赖和协调。
导入相关包
import argparse
from mindspore_rl.algorithm.ac.ac_trainer import ACTrainer
from mindspore_rl.algorithm.ac.ac_session import ACSession
from mindspore import context
parser = argparse.ArgumentParser(description='MindSpore Reinforcement AC')
parser.add_argument('--episode', type=int, default=1000, help='total episode numbers.')
parser.add_argument('--device_target', type=str, default='Auto', choices=['Ascend', 'CPU', 'GPU', 'Auto'],
help='Choose a device to run the ac example(Default: Auto).')
parser.add_argument('--env_yaml', type=str, default='../env_yaml/CartPole-v0.yaml',
help='Choose an environment yaml to update the ac example(Default: CartPole-v0.yaml).')
parser.add_argument('--algo_yaml', type=str, default=None,
help='Choose an algo yaml to update the ac example(Default: None).')
options, _ = parser.parse_known_args()
启动环境
episode=options.episode
"""start to train ac algorithm"""
if options.device_target != 'Auto':
context.set_context(device_target=options.device_target)
if context.get_context('device_target') in ['CPU']:
context.set_context(enable_graph_kernel=True)
context.set_context(mode=context.GRAPH_MODE)
ac_session = ACSession(options.env_yaml, options.algo_yaml)
上下文管理
import sys
import time
from io import StringIO
class RealTimeCaptureAndDisplayOutput(object):
def __init__(self):
self._original_stdout = sys.stdout
self._original_stderr = sys.stderr
self.captured_output = StringIO()
def write(self, text):
self._original_stdout.write(text) # 实时打印
self.captured_output.write(text) # 保存到缓冲区
def flush(self):
self._original_stdout.flush()
self.captured_output.flush()
def __enter__(self):
sys.stdout = self
sys.stderr = self
return self
def __exit__(self, exc_type, exc_val, exc_tb):
sys.stdout = self._original_stdout
sys.stderr = self._original_stderr
episode=100
# dqn_session.run(class_type=DQNTrainer, episode=episode)
with RealTimeCaptureAndDisplayOutput() as captured_new:
ac_session.run(class_type=ACTrainer, episode=episode)
import re
import matplotlib.pyplot as plt
# 原始输出
raw_output = captured_new.captured_output.getvalue()
# 使用正则表达式从输出中提取loss和rewards
loss_pattern = r"loss is (\d+\.\d+)"
reward_pattern = r"rewards is (\d+\.\d+)"
loss_values = [float(match.group(1)) for match in re.finditer(loss_pattern, raw_output)]
reward_values = [float(match.group(1)) for match in re.finditer(reward_pattern, raw_output)]
# 绘制loss曲线
plt.plot(loss_values, label='Loss')
plt.xlabel('Episode')
plt.ylabel('Loss')
plt.title('Loss Curve')
plt.legend()
plt.show()
# 绘制reward曲线
plt.plot(reward_values, label='Rewards')
plt.xlabel('Episode')
plt.ylabel('Rewards')
plt.title('Rewards Curve')
plt.legend()
plt.show()
【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱:
cloudbbs@huaweicloud.com
- 点赞
- 收藏
- 关注作者
评论(0)