【MADRL】多智能体价值分解网络(VDN)算法

举报
不去幼儿园 发表于 2024/12/20 11:11:56 2024/12/20
458 0 0
【摘要】 多智能体强化学习(MARL, Multi-Agent Reinforcement Learning)中,一个关键挑战是如何在多个智能体的协作环境下学习有效的策略。价值分解网络(VDN, Value Decomposition Network)是解决这一问题的一种重要方法,特别是在 集中训练,分散执行

         本篇文章是博主强化学习RL领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在强化学习专栏:

      【 强化学习】(11)---《多智能体价值分解网络(VDN)算法》

多智能体价值分解网络(VDN)算法

目录

1.算法介绍

2.VDN 算法概述

2.1价值分解

2.2训练过程

2.3损失函数

3.VDN算法的优势与局限

[Python] VDN伪代码实现

[Pytorch] VDN 实现(可移植)


1.算法介绍

        多智能体强化学习(MARL, Multi-Agent Reinforcement Learning)中,一个关键挑战是如何在多个智能体的协作环境下学习有效的策略。价值分解网络(VDN, Value Decomposition Network)是解决这一问题的一种重要方法,特别是在 集中训练,分散执行(CTDE, Centralized Training and Decentralized Execution)框架中,VDN提供了一种分解联合价值函数的策略,使得多个智能体可以高效协作并学习。

论文:Value-Decomposition Networks For Cooperative Multi-Agent Learning

代码: MADRL多智能体价值分解网络(VDN)算法

1.1背景与动机

        在多智能体系统中,每个智能体不仅需要根据自己的观察做出决策,还需要与其他智能体协作以实现全局目标。例如,在团队作战游戏中,每个智能体(如士兵)都有局部信息,但他们的行动需要协调以赢得整场比赛。这时,直接学习一个全局的Q值函数(即联合价值函数)来指导所有智能体的动作选择变得非常复杂,因为状态空间和动作空间随智能体数量呈指数增长。

        VDN提出了一种基于 联合价值函数的分解 方法,将全局Q值函数分解为多个独立智能体的局部Q值函数,从而使得问题规模显著降低,并能保证智能体之间的协作。


2.VDN 算法概述

        VDN算法的核心思想是将多个智能体的 联合Q值函数 分解为 每个智能体的局部Q值之和。在这种结构下,每个智能体学习自己的局部Q值函数 (Q_i),然后通过简单的求和操作得到全局的联合Q值(Q_{tot})。这一分解形式使得每个智能体可以独立执行决策,同时在集中训练阶段依然能够学到全局最优策略。

2.1价值分解

        在传统的单智能体强化学习中,Q值函数( Q(s, a) )表示在状态(s)下采取动作(a)的价值。对于多智能体系统,联合Q值函数( Q_{tot}(s, \mathbf{a}) )表示在状态(s)下所有智能体联合动作 ( \mathbf{a} = (a_1, a_2, \dots, a_N) ) 的总价值。

        VDN假设联合Q值函数可以通过每个智能体的局部Q值函数(Q_i(o_i, a_i))进行线性分解:

[ Q_{tot}(s, \mathbf{a}) = \sum_{i=1}^N Q_i(o_i, a_i) ]

其中:

  • (N) 是智能体的数量,
  • (o_i)是智能体(i)的局部观察,
  • (a_i)是智能体 (i) 的动作,
  • (Q_i(o_i, a_i))是智能体(i)基于自己的局部观察(o_i) 和动作(a_i)所学习到的局部Q值。

        这种线性分解的方式使得各个智能体可以在执行时独立做出动作选择,同时在集中训练时通过全局Q值函数来优化策略。

2.2训练过程

        VDN的训练采用集中训练、分散执行(CTDE)模式:

  • 集中训练:训练时可以访问所有智能体的全局信息,如全局状态 (s) 和联合动作 ( \mathbf{a} ),利用这些信息来计算全局的目标函数(如回报值)。同时,联合Q值函数通过局部Q值函数的和来计算和更新。
  • 分散执行:在执行阶段,智能体只能基于自己的局部信息(o_i) 和学习到的局部Q值函数 (Q_i(o_i, a_i)) 进行动作选择。

        通过这种方式,每个智能体都可以独立执行,而在训练阶段又能确保全局最优解的学习。

2.3损失函数

        训练过程中,VDN的损失函数与传统的Q-learning类似,基于 TD误差(Temporal Difference error)来更新Q值。对于给定的经验样本( (s, \mathbf{a}, r, s') ),损失函数为:

[ \mathcal{L} = \mathbb{E} \left[ \left( r + \gamma \max_{\mathbf{a}'} Q_{tot}(s', \mathbf{a}') - Q_{tot}(s, \mathbf{a}) \right)^2 \right] ]

其中:

  • (r)是环境给出的全局回报,
  • (\gamma)是折扣因子,
  • (s')是下一个状态,
  • (\mathbf{a}')是在下一个状态下的最优联合动作。

由于(Q_{tot}(s, \mathbf{a}))是通过每个局部Q值的和来计算的,更新 (Q_{tot})的同时会更新每个智能体的局部Q值 (Q_i)


3.VDN算法的优势与局限

3.1优势

  1. 简化联合Q值学习:VDN将全局Q值函数分解为多个局部Q值函数,显著减少了学习的复杂性,特别是对于有较多智能体的系统。
  2. 分散执行:每个智能体只需根据自己的局部观察和Q值进行决策,不依赖其他智能体的具体动作,适用于具有局部观测的多智能体任务。
  3. 协作能力:通过联合Q值函数的分解,VDN能够有效地促进智能体之间的协作学习,有利于解决团队协作任务。

3.2局限

  1. 线性分解的限制:VDN采用线性求和的方式分解联合Q值,这种方法虽然简单,但可能无法捕捉复杂的智能体之间的非线性协作关系。在某些场景下,简单的线性分解无法保证找到全局最优策略。

  2. 非完全的协作信息:虽然VDN能够在一定程度上促进协作,但由于局部Q值与联合Q值之间的联系较弱,可能导致智能体之间的信息交换不充分,尤其是在非完全协作的环境中,智能体可能无法充分学习到全局最优策略。

3.4VDN的扩展:QMIX

        为了克服VDN的线性分解限制,QMIX算法提出了一种非线性价值分解方法。与VDN不同,QMIX使用一个混合网络来学习非线性的联合Q值,能够捕捉智能体之间更加复杂的协作关系。QMIX 的核心思想是通过一个可混合网络将局部Q值映射为联合Q值,并保证联合Q值是单调递增的,以确保分散执行时的最优性。


 [Python] VDN伪代码实现

python
# VDN Algorithm
def VDN_train(batch, gamma):
    # Extract necessary data from the batch
    states, actions, rewards, next_states, done_flags = batch
    
    # Initialize the total Q-values (Q_tot)
    Q_tot = 0
    
    # Loop over each agent
    for agent_id in range(num_agents):
        # Get agent's local observations and actions
        local_observations = states[:, agent_id]
        local_actions = actions[:, agent_id]
        
        # Get the agent's local Q-values (Q_i)
        Q_i = agent_local_Q_function(local_observations, local_actions)
        
        # Sum the local Q-values to get Q_tot
        Q_tot += Q_i
    
    # Compute target Q_tot using next states
    next_Q_tot = 0
    for agent_id in range(num_agents):
        next_observations = next_states[:, agent_id]
        next_Q_i = agent_local_Q_function(next_observations)
        next_Q_tot += next_Q_i
    
    # Calculate the target using Bellman equation
    target_Q_tot = rewards + gamma * (1 - done_flags) * next_Q_tot
    
    # Calculate the TD error and update Q-functions
    loss = (Q_tot - target_Q_tot).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss

[Pytorch] VDN 实现(可移植)

        若是下面代码复现困难或者有问题,欢迎评论区留言;需要以整个项目形式的代码,请在评论区留下您的邮箱,以便于及时分享给您(私信难以及时回复)。

import torch
import numpy as np
from torch.utils.tensorboard import SummaryWriter
from smac.env import StarCraft2Env
import argparse
from replay_buffer import ReplayBuffer
from qmix_smac import QMIX_SMAC
from normalization import Normalization
class Runner_QMIX_SMAC:
    def __init__(self, args, env_name, number, seed):
        self.args = args
        self.env_name = env_name
        self.number = number
        self.seed = seed
        # Set random seed
        np.random.seed(self.seed)
        torch.manual_seed(self.seed)
        # Create env
        self.env = StarCraft2Env(map_name=self.env_name, seed=self.seed)
        self.env_info = self.env.get_env_info()
        self.args.N = self.env_info["n_agents"]  # The number of agents
        self.args.obs_dim = self.env_info["obs_shape"]  # The dimensions of an agent's observation space
        self.args.state_dim = self.env_info["state_shape"]  # The dimensions of global state space
        self.args.action_dim = self.env_info["n_actions"]  # The dimensions of an agent's action space
        self.args.episode_limit = self.env_info["episode_limit"]  # Maximum number of steps per episode
        print("number of agents={}".format(self.args.N))
        print("obs_dim={}".format(self.args.obs_dim))
        print("state_dim={}".format(self.args.state_dim))
        print("action_dim={}".format(self.args.action_dim))
        print("episode_limit={}".format(self.args.episode_limit))

        # Create N agents
        self.agent_n = QMIX_SMAC(self.args)
        self.replay_buffer = ReplayBuffer(self.args)

        # Create a tensorboard
        self.writer = SummaryWriter(log_dir='./runs/{}/{}_env_{}_number_{}_seed_{}'.format(self.args.algorithm, self.args.algorithm, self.env_name, self.number, self.seed))

        self.epsilon = self.args.epsilon  # Initialize the epsilon
        self.win_rates = []  # Record the win rates
        self.total_steps = 0
        if self.args.use_reward_norm:
            print("------use reward norm------")
            self.reward_norm = Normalization(shape=1)

    def run(self, ):
        evaluate_num = -1  # Record the number of evaluations
        while self.total_steps < self.args.max_train_steps:
            if self.total_steps // self.args.evaluate_freq > evaluate_num:
                self.evaluate_policy()  # Evaluate the policy every 'evaluate_freq' steps
                evaluate_num += 1

            _, _, episode_steps = self.run_episode_smac(evaluate=False)  # Run an episode
            self.total_steps += episode_steps

            if self.replay_buffer.current_size >= self.args.batch_size:
                self.agent_n.train(self.replay_buffer, self.total_steps)  # Training

        self.evaluate_policy()
        self.env.close()

    def evaluate_policy(self, ):
        win_times = 0
        evaluate_reward = 0
        for _ in range(self.args.evaluate_times):
            win_tag, episode_reward, _ = self.run_episode_smac(evaluate=True)
            if win_tag:
                win_times += 1
            evaluate_reward += episode_reward

        win_rate = win_times / self.args.evaluate_times
        evaluate_reward = evaluate_reward / self.args.evaluate_times
        self.win_rates.append(win_rate)
        print("total_steps:{} \t win_rate:{} \t evaluate_reward:{}".format(self.total_steps, win_rate, evaluate_reward))
        self.writer.add_scalar('win_rate_{}'.format(self.env_name), win_rate, global_step=self.total_steps)
        # Save the win rates
        np.save('./data_train/{}_env_{}_number_{}_seed_{}.npy'.format(self.args.algorithm, self.env_name, self.number, self.seed), np.array(self.win_rates))

    def run_episode_smac(self, evaluate=False):
        win_tag = False
        episode_reward = 0
        self.env.reset()
        if self.args.use_rnn:  # If use RNN, before the beginning of each episode,reset the rnn_hidden of the Q network.
            self.agent_n.eval_Q_net.rnn_hidden = None
        last_onehot_a_n = np.zeros((self.args.N, self.args.action_dim))  # Last actions of N agents(one-hot)
        for episode_step in range(self.args.episode_limit):
            obs_n = self.env.get_obs()  # obs_n.shape=(N,obs_dim)
            s = self.env.get_state()  # s.shape=(state_dim,)
            avail_a_n = self.env.get_avail_actions()  # Get available actions of N agents, avail_a_n.shape=(N,action_dim)
            epsilon = 0 if evaluate else self.epsilon
            a_n = self.agent_n.choose_action(obs_n, last_onehot_a_n, avail_a_n, epsilon)
            last_onehot_a_n = np.eye(self.args.action_dim)[a_n]  # Convert actions to one-hot vectors
            r, done, info = self.env.step(a_n)  # Take a step
            win_tag = True if done and 'battle_won' in info and info['battle_won'] else False
            episode_reward += r

            if not evaluate:
                if self.args.use_reward_norm:
                    r = self.reward_norm(r)
                """"
                    When dead or win or reaching the episode_limit, done will be Ture, we need to distinguish them;
                    dw means dead or win,there is no next state s';
                    but when reaching the max_episode_steps,there is a next state s' actually.
                """
                if done and episode_step + 1 != self.args.episode_limit:
                    dw = True
                else:
                    dw = False

                # Store the transition
                self.replay_buffer.store_transition(episode_step, obs_n, s, avail_a_n, last_onehot_a_n, a_n, r, dw)
                # Decay the epsilon
                self.epsilon = self.epsilon - self.args.epsilon_decay if self.epsilon - self.args.epsilon_decay > self.args.epsilon_min else self.args.epsilon_min

            if done:
                break

        if not evaluate:
            # An episode is over, store obs_n, s and avail_a_n in the last step
            obs_n = self.env.get_obs()
            s = self.env.get_state()
            avail_a_n = self.env.get_avail_actions()
            self.replay_buffer.store_last_step(episode_step + 1, obs_n, s, avail_a_n)

        return win_tag, episode_reward, episode_step + 1
if __name__ == '__main__':
    parser = argparse.ArgumentParser("Hyperparameter Setting for QMIX and VDN in SMAC environment")
    parser.add_argument("--max_train_steps", type=int, default=int(1e6), help=" Maximum number of training steps")
    parser.add_argument("--evaluate_freq", type=float, default=5000, help="Evaluate the policy every 'evaluate_freq' steps")
    parser.add_argument("--evaluate_times", type=float, default=32, help="Evaluate times")
    parser.add_argument("--save_freq", type=int, default=int(1e5), help="Save frequency")

    parser.add_argument("--algorithm", type=str, default="VDN", help="QMIX or VDN")
    parser.add_argument("--epsilon", type=float, default=1.0, help="Initial epsilon")
    parser.add_argument("--epsilon_decay_steps", type=float, default=50000, help="How many steps before the epsilon decays to the minimum")
    parser.add_argument("--epsilon_min", type=float, default=0.05, help="Minimum epsilon")
    parser.add_argument("--buffer_size", type=int, default=5000, help="The capacity of the replay buffer")
    parser.add_argument("--batch_size", type=int, default=32, help="Batch size (the number of episodes)")
    parser.add_argument("--lr", type=float, default=5e-4, help="Learning rate")
    parser.add_argument("--gamma", type=float, default=0.99, help="Discount factor")
    parser.add_argument("--qmix_hidden_dim", type=int, default=32, help="The dimension of the hidden layer of the QMIX network")
    parser.add_argument("--hyper_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of the hyper-network")
    parser.add_argument("--hyper_layers_num", type=int, default=1, help="The number of layers of hyper-network")
    parser.add_argument("--rnn_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of RNN")
    parser.add_argument("--mlp_hidden_dim", type=int, default=64, help="The dimension of the hidden layer of MLP")
    parser.add_argument("--use_rnn", type=bool, default=True, help="Whether to use RNN")
    parser.add_argument("--use_orthogonal_init", type=bool, default=True, help="Orthogonal initialization")
    parser.add_argument("--use_grad_clip", type=bool, default=True, help="Gradient clip")
    parser.add_argument("--use_lr_decay", type=bool, default=False, help="use lr decay")
    parser.add_argument("--use_RMS", type=bool, default=False, help="Whether to use RMS,if False, we will use Adam")
    parser.add_argument("--add_last_action", type=bool, default=True, help="Whether to add last actions into the observation")
    parser.add_argument("--add_agent_id", type=bool, default=True, help="Whether to add agent id into the observation")
    parser.add_argument("--use_double_q", type=bool, default=True, help="Whether to use double q-learning")
    parser.add_argument("--use_reward_norm", type=bool, default=False, help="Whether to use reward normalization")
    parser.add_argument("--use_hard_update", type=bool, default=True, help="Whether to use hard update")
    parser.add_argument("--target_update_freq", type=int, default=200, help="Update frequency of the target network")
    parser.add_argument("--tau", type=int, default=0.005, help="If use soft update")

    args = parser.parse_args()
    args.epsilon_decay = (args.epsilon - args.epsilon_min) / args.epsilon_decay_steps

    env_names = ['3m', '8m', '2s3z']
    env_index = 0
    runner = Runner_QMIX_SMAC(args, env_name=env_names[env_index], number=1, seed=0)
    runner.run()

移植事项:

1.注意环境参数的设置格式

2.注意环境的返回值利用

3.注意主运行流程的runner.run()的相关设置,等

可借鉴:【MADRL】基于MADRL的单调价值函数分解(QMIX)算法​​​​​​ 中关于 QMIX算法移植的注意事项和代码注释。 或者将上述QMIX算法代码部分的"QMIX" 改成 "VDN"即可

parser.add_argument("--algorithm", type=str, default="VDN", help="QMIX or VDN")

     文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者关注VX公众号:Rain21321,联系作者。

【版权声明】本文为华为云社区用户原创内容,未经允许不得转载,如需转载请自行联系原作者进行授权。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

作者其他文章

评论(0

抱歉,系统识别当前为高风险访问,暂不支持该操作

    全部回复

    上滑加载中

    设置昵称

    在此一键设置昵称,即可参与社区互动!

    *长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

    *长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。