本篇文章是博主强化学习RL领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在强化学习专栏:
【 强化学习】(11)---《多智能体价值分解网络(VDN)算法》
多智能体价值分解网络(VDN)算法
目录
1.算法介绍
2.VDN 算法概述
2.1价值分解
2.2训练过程
2.3损失函数
3.VDN算法的优势与局限
[Python] VDN伪代码实现
[Pytorch] VDN 实现(可移植)
1.算法介绍
多智能体强化学习(MARL, Multi-Agent Reinforcement Learning)中,一个关键挑战是如何在多个智能体的协作环境下学习有效的策略。价值分解网络(VDN, Value Decomposition Network)是解决这一问题的一种重要方法,特别是在 集中训练,分散执行(CTDE, Centralized Training and Decentralized Execution)框架中,VDN提供了一种分解联合价值函数的策略,使得多个智能体可以高效协作并学习。
论文:Value-Decomposition Networks For Cooperative Multi-Agent Learning
代码: MADRL多智能体价值分解网络(VDN)算法
1.1背景与动机
在多智能体系统中,每个智能体不仅需要根据自己的观察做出决策,还需要与其他智能体协作以实现全局目标。例如,在团队作战游戏中,每个智能体(如士兵)都有局部信息,但他们的行动需要协调以赢得整场比赛。这时,直接学习一个全局的Q值函数(即联合价值函数)来指导所有智能体的动作选择变得非常复杂,因为状态空间和动作空间随智能体数量呈指数增长。
VDN提出了一种基于 联合价值函数的分解 方法,将全局Q值函数分解为多个独立智能体的局部Q值函数,从而使得问题规模显著降低,并能保证智能体之间的协作。
2.VDN 算法概述
VDN算法的核心思想是将多个智能体的 联合Q值函数 分解为 每个智能体的局部Q值之和。在这种结构下,每个智能体学习自己的局部Q值函数
,然后通过简单的求和操作得到全局的联合Q值
。这一分解形式使得每个智能体可以独立执行决策,同时在集中训练阶段依然能够学到全局最优策略。
2.1价值分解
在传统的单智能体强化学习中,Q值函数
表示在状态
下采取动作
的价值。对于多智能体系统,联合Q值函数
表示在状态
下所有智能体联合动作
的总价值。
VDN假设联合Q值函数可以通过每个智能体的局部Q值函数
进行线性分解:
![[ Q_{tot}(s, \mathbf{a}) = \sum_{i=1}^N Q_i(o_i, a_i) ]](https://res.hc-cdn.com/ecology/9.3.164/v2_resources/ydcomm/libs/images/loading.gif)
其中:
是智能体的数量,
是智能体
的局部观察,
是智能体
的动作,
是智能体
基于自己的局部观察
和动作
所学习到的局部Q值。
这种线性分解的方式使得各个智能体可以在执行时独立做出动作选择,同时在集中训练时通过全局Q值函数来优化策略。
2.2训练过程
VDN的训练采用集中训练、分散执行(CTDE)模式:
- 集中训练:训练时可以访问所有智能体的全局信息,如全局状态 (s) 和联合动作
,利用这些信息来计算全局的目标函数(如回报值)。同时,联合Q值函数通过局部Q值函数的和来计算和更新。
- 分散执行:在执行阶段,智能体只能基于自己的局部信息
和学习到的局部Q值函数
进行动作选择。
通过这种方式,每个智能体都可以独立执行,而在训练阶段又能确保全局最优解的学习。
2.3损失函数
训练过程中,VDN的损失函数与传统的Q-learning类似,基于 TD误差(Temporal Difference error)来更新Q值。对于给定的经验样本
,损失函数为:
![[ \mathcal{L} = \mathbb{E} \left[ \left( r + \gamma \max_{\mathbf{a}'} Q_{tot}(s', \mathbf{a}') - Q_{tot}(s, \mathbf{a}) \right)^2 \right] ]](https://res.hc-cdn.com/ecology/9.3.164/v2_resources/ydcomm/libs/images/loading.gif)
其中:
是环境给出的全局回报,
是折扣因子,
是下一个状态,
是在下一个状态下的最优联合动作。
由于
是通过每个局部Q值的和来计算的,更新
的同时会更新每个智能体的局部Q值
。
3.VDN算法的优势与局限
3.1优势
- 简化联合Q值学习:VDN将全局Q值函数分解为多个局部Q值函数,显著减少了学习的复杂性,特别是对于有较多智能体的系统。
- 分散执行:每个智能体只需根据自己的局部观察和Q值进行决策,不依赖其他智能体的具体动作,适用于具有局部观测的多智能体任务。
- 协作能力:通过联合Q值函数的分解,VDN能够有效地促进智能体之间的协作学习,有利于解决团队协作任务。
3.2局限
-
线性分解的限制:VDN采用线性求和的方式分解联合Q值,这种方法虽然简单,但可能无法捕捉复杂的智能体之间的非线性协作关系。在某些场景下,简单的线性分解无法保证找到全局最优策略。
-
非完全的协作信息:虽然VDN能够在一定程度上促进协作,但由于局部Q值与联合Q值之间的联系较弱,可能导致智能体之间的信息交换不充分,尤其是在非完全协作的环境中,智能体可能无法充分学习到全局最优策略。
3.4VDN的扩展:QMIX
为了克服VDN的线性分解限制,QMIX算法提出了一种非线性价值分解方法。与VDN不同,QMIX使用一个混合网络来学习非线性的联合Q值,能够捕捉智能体之间更加复杂的协作关系。QMIX 的核心思想是通过一个可混合网络将局部Q值映射为联合Q值,并保证联合Q值是单调递增的,以确保分散执行时的最优性。
[Python] VDN伪代码实现
[Pytorch] VDN 实现(可移植)
若是下面代码复现困难或者有问题,欢迎评论区留言;需要以整个项目形式的代码,请在评论区留下您的邮箱,以便于及时分享给您(私信难以及时回复)。
import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from smac.env import StarCraft2Env
import argparse
from replay_buffer import ReplayBuffer
from qmix_smac import QMIX_SMAC
from normalization import Normalization
class Runner_QMIX_SMAC:
def __init__(self, args, env_name, number, seed):
self.args = args
self.env_name = env_name
self.number = number
self.seed = seed
np.random.seed(self.seed)
torch.manual_seed(self.seed)
self.env = StarCraft2Env(map_name=self.env_name, seed=self.seed)
self.env_info = self.env.get_env_info()
self.args.N = self.env_info["n_agents"]
self.args.obs_dim = self.env_info["obs_shape"]
self.args.state_dim = self.env_info["state_shape"]
self.args.action_dim = self.env_info["n_actions"]
self.args.episode_limit = self.env_info["episode_limit"]
print("number of agents={}".format(self.args.N))
print("obs_dim={}".format(self.args.obs_dim))
print("state_dim={}".format(self.args.state_dim))
print("action_dim={}".format(self.args.action_dim))
print("episode_limit={}".format(self.args.episode_limit))
self.agent_n = QMIX_SMAC(self.args)
self.replay_buffer = ReplayBuffer(self.args)
self.writer = SummaryWriter(log_dir='./runs/{}/{}_env_{}_number_{}_seed_{}'.format(self.args.algorithm, self.args.algorithm, self.env_name, self.number, self.seed))
self.epsilon = self.args.epsilon
self.win_rates = []
self.total_steps = 0
if self.args.use_reward_norm:
print("------use reward norm------")
self.reward_norm = Normalization(shape=1)
def run(self, ):
evaluate_num = -1
while self.total_steps < self.args.max_train_steps:
if self.total_steps // self.args.evaluate_freq > evaluate_num:
self.evaluate_policy()
evaluate_num += 1
_, _, episode_steps = self.run_episode_smac(evaluate=False)
self.total_steps += episode_steps
if self.replay_buffer.current_size >= self.args.batch_size:
self.agent_n.train(self.replay_buffer, self.total_steps)
self.evaluate_policy()
self.env.close()
def evaluate_policy(self, ):
win_times = 0
evaluate_reward = 0
for _ in range(self.args.evaluate_times):
win_tag, episode_reward, _ = self.run_episode_smac(evaluate=True)
if win_tag:
win_times += 1
evaluate_reward += episode_reward
win_rate = win_times / self.args.evaluate_times
evaluate_reward = evaluate_reward / self.args.evaluate_times
self.win_rates.append(win_rate)
print("total_steps:{} \t win_rate:{} \t evaluate_reward:{}".format(self.total_steps, win_rate, evaluate_reward))
self.writer.add_scalar('win_rate_{}'.format(self.env_name), win_rate, global_step=self.total_steps)
np.save('./data_train/{}_env_{}_number_{}_seed_{}.npy'.format(self.args.algorithm, self.env_name, self.number, self.seed), np.array(self.win_rates))
def run_episode_smac(self, evaluate=False):
win_tag = False
episode_reward = 0
self.env.reset()
if self.args.use_rnn:
self.agent_n.eval_Q_net.rnn_hidden = None
last_onehot_a_n = np.zeros((self.args.N, self.args.action_dim))
for episode_step in range(self.args.episode_limit):
obs_n = self.env.get_obs()
s = self.env.get_state()
avail_a_n = self.env.get_avail_actions()
epsilon = 0 if evaluate else self.epsilon
a_n = self.agent_n.choose_action(obs_n, last_onehot_a_n, avail_a_n, epsilon)
last_onehot_a_n = np.eye(self.args.action_dim)[a_n]
r, done, info = self.env.step(a_n)
win_tag = True if done and 'battle_won' in info and info['battle_won'] else False
episode_reward += r
if not evaluate:
if self.args.use_reward_norm:
r = self.reward_norm(r)
""""
When dead or win or reaching the episode_limit, done will be Ture, we need to distinguish them;
dw means dead or win,there is no next state s';
but when reaching the max_episode_steps,there is a next state s' actually.
"""
if done and episode_step + 1 != self.args.episode_limit:
dw = True
else:
dw = False
self.replay_buffer.store_transition(episode_step, obs_n, s, avail_a_n, last_onehot_a_n, a_n, r, dw)
self.epsilon = self.epsilon - self.args.epsilon_decay if self.epsilon - self.args.epsilon_decay > self.args.epsilon_min else self.args.epsilon_min
if done:
break
if not evaluate:
obs_n = self.env.get_obs()
s = self.env.get_state()
avail_a_n = self.env.get_avail_actions()
self.replay_buffer.store_last_step(episode_step + 1, obs_n, s, avail_a_n)
return win_tag, episode_reward, episode_step + 1
if __name__ == '__main__':
parser = argparse.ArgumentParser("Hyperparameter Setting for QMIX and VDN in SMAC environment")
parser.add_argument("--max_train_steps", type=int, default=int(1e6), help=" Maximum number of training steps")
parser.add_argument("--evaluate_freq", type=float, default=5000, help="Evaluate the policy every 'evaluate_freq' steps")
parser.add_argument("--evaluate_times", type=float, default=32, help="Evaluate times")
parser.add_argument("--save_freq", type=int, default=int(1e5), help="Save frequency")
parser.add_argument("--algorithm", type=str, default="VDN", help="QMIX or VDN")
parser.add_argument("--epsilon", type=float, default=1.0, help="Initial epsilon")
parser.add_argument("--epsilon_decay_steps", type=float, default=50000, help="How many steps before the epsilon decays to the minimum")
parser.add_argument("--epsilon_min", type=float, default=0.05, help="Minimum epsilon")
parser.add_argument("--buffer_size", type=int, default=5000, help="The capacity of the replay buffer")
parser.add_argument("--batch_size", type=int, default=32, help="Batch size (the number of episodes)")
parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
parser.add_argument("--qmix_hidden_dim", type=int, default=32, help="The dimension of the hidden layer of the QMIX network")
parser.add_argument("--hyper_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of the hyper-network")
parser.add_argument("--hyper_layers_num", type=int, default=1, help="The number of layers of hyper-network")
parser.add_argument("--rnn_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of RNN")
parser.add_argument("--mlp_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of MLP")
parser.add_argument("--use_rnn", type=bool, default=True, help="Whether to use RNN")
parser.add_argument("--use_orthogonal_init", type=bool, default=True, help="Orthogonal initialization")
parser.add_argument("--use_grad_clip", type=bool, default=True, help="Gradient clip")
parser.add_argument("--use_lr_decay", type=bool, default=False, help="use lr decay")
parser.add_argument("--use_RMS", type=bool, default=False, help="Whether to use RMS,if False, we will use Adam")
parser.add_argument("--add_last_action", type=bool, default=True, help="Whether to add last actions into the observation")
parser.add_argument("--add_agent_id", type=bool, default=True, help="Whether to add agent id into the observation")
parser.add_argument("--use_double_q", type=bool, default=True, help="Whether to use double q-learning")
parser.add_argument("--use_reward_norm", type=bool, default=False, help="Whether to use reward normalization")
parser.add_argument("--use_hard_update", type=bool, default=True, help="Whether to use hard update")
parser.add_argument("--target_update_freq", type=int, default=200, help="Update frequency of the target network")
parser.add_argument("--tau", type=int, default=0.005, help="If use soft update")
args = parser.parse_args()
args.epsilon_decay = (args.epsilon - args.epsilon_min) / args.epsilon_decay_steps
env_names = ['3m', '8m', '2s3z']
env_index = 0
runner = Runner_QMIX_SMAC(args, env_name=env_names[env_index], number=1, seed=0)
runner.run()
移植事项:
1.注意环境参数的设置格式
2.注意环境的返回值利用
3.注意主运行流程的runner.run()的相关设置,等
可借鉴:【MADRL】基于MADRL的单调价值函数分解(QMIX)算法 中关于 QMIX算法移植的注意事项和代码注释。 或者将上述QMIX算法代码部分的"QMIX" 改成 "VDN"即可
文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者关注VX公众号:Rain21321,联系作者。
评论(0)