本篇文章是博主强化学习RL领域学习时,用于个人学习、研究或者欣赏使用,并基于博主对相关等领域的一些理解而记录的学习摘录和笔记,若有不当和侵权之处,指出后将会立即改正,还望谅解。文章分类在强化学习专栏:
强化学习(9)---《【MADRL】反事实多智能体策略梯度(COMA)算法》
【MADRL】反事实多智能体策略梯度(COMA)算法
目录
0.介绍
1.算法背景和思想
2.公式推导
3.COMA 算法步骤
4.优势
5.应用场景
[Python] COMA 算法实现
0.介绍
反事实多智能体策略梯度法COMA (Counterfactual Multi-Agent Policy Gradient) 是一种面向多智能体协作问题的强化学习算法,旨在通过减少策略梯度的方差,来提升去中心化智能体的学习效果。
COMA 算法最早由 DeepMind 团队提出,论文标题为 "Counterfactual Multi-Agent Policy Gradients",由 Jakob Foerster 等人于 2017 年在 AAAI 会议上发表。适用于局部观察、去中心化决策的多智能体环境,特别是策略梯度方法下的合作问题。
参考文献:Counterfactual Multi-Agent Policy Gradients
1.算法背景和思想
在多智能体强化学习场景中,每个智能体在某一时刻只掌握局部的信息,无法全局观测环境状态。为了促进合作,各个智能体的动作对全局奖励有不同的贡献,因此需要一种有效的方法来分配奖励。COMA 引入了“反事实基线”(Counterfactual Baseline)的概念,专门用于降低多智能体策略梯度方法中的方差。
COMA 的核心思想是通过引入一个基线,该基线模拟在固定其他智能体动作的前提下,某个智能体选择不同动作时对全局奖励的影响,从而更精确地衡量当前动作的贡献,减少策略梯度更新中的方差。

2.公式推导
-
全局策略梯度:对于多智能体问题,每个智能体
的策略梯度可以表示为:

其中,
是智能体
在状态
下选择动作
的概率,
是该智能体在执行动作
时的动作价值函数。
-
反事实基线:为了减小方差,COMA 提出了反事实基线
,该基线衡量在保持其他智能体动作
不变的情况下,智能体
选择其他动作时的期望收益。具体公式为:

这里的
表示除智能体
之外,其他智能体的动作组合。
-
策略更新:有了反事实基线之后,COMA 中智能体
的策略更新公式变为:

其中,
是在当前状态下,所有智能体执行一组动作
时的全局奖励,而
是该动作的反事实基线。
-
全局值函数:COMA 中的值函数
和基线
都是通过集中化的学习进行优化的,虽然决策是去中心化的,但值函数和基线都依赖于全局的状态和动作信息。
3.COMA 算法步骤
- 初始化智能体策略和集中式的全局值函数。
- 智能体与环境交互,收集经验数据。
- 使用经验数据更新全局值函数
。
- 计算反事实基线
。
- 计算每个智能体的策略梯度,并更新策略参数。
- 重复上述过程,直至智能体策略收敛。
4.优势
- 减少策略梯度的方差:通过引入反事实基线,有效地减少了策略更新过程中的高方差问题,使得策略更新更加稳定。
- 适合多智能体协作环境:COMA 尤其适合智能体需要紧密合作的场景,比如多智能体联合行动或团队任务。
5.应用场景
- 多智能体合作任务,如多机器人协作、多无人机编队等。
- 需要去中心化控制、但全局奖励信息可用的场景。
[Python] COMA 算法实现
若是下面代码复现困难或者有问题,欢迎评论区留言;需要以整个项目形式的代码,请在评论区留下您的邮箱,以便于及时分享给您(私信难以及时回复)。
主要代码:
import torch
import os
from network.base_net import RNN
from network.commnet import CommNet
from network.g2anet import G2ANet
from network.coma_critic import ComaCritic
from common.utils import td_lambda_target
class COMA:
def __init__(self, args):
self.n_actions = args.n_actions
self.n_agents = args.n_agents
self.state_shape = args.state_shape
self.obs_shape = args.obs_shape
actor_input_shape = self.obs_shape
critic_input_shape = self._get_critic_input_shape()
if args.last_action:
actor_input_shape += self.n_actions
if args.reuse_network:
actor_input_shape += self.n_agents
self.args = args
if self.args.alg == 'coma':
print('Init alg coma')
self.eval_rnn = RNN(actor_input_shape, args)
elif self.args.alg == 'coma+commnet':
print('Init alg coma+commnet')
self.eval_rnn = CommNet(actor_input_shape, args)
elif self.args.alg == 'coma+g2anet':
print('Init alg coma+g2anet')
self.eval_rnn = G2ANet(actor_input_shape, args)
else:
raise Exception("No such algorithm")
self.eval_critic = ComaCritic(critic_input_shape, self.args)
self.target_critic = ComaCritic(critic_input_shape, self.args)
if self.args.cuda:
self.eval_rnn.cuda()
self.eval_critic.cuda()
self.target_critic.cuda()
self.model_dir = args.model_dir + '/' + args.alg + '/' + args.map
if self.args.load_model:
if os.path.exists(self.model_dir + '/rnn_params.pkl'):
path_rnn = self.model_dir + '/rnn_params.pkl'
path_coma = self.model_dir + '/critic_params.pkl'
map_location = 'cuda:0' if self.args.cuda else 'cpu'
self.eval_rnn.load_state_dict(torch.load(path_rnn, map_location=map_location))
self.eval_critic.load_state_dict(torch.load(path_coma, map_location=map_location))
print('Successfully load the model: {} and {}'.format(path_rnn, path_coma))
else:
raise Exception("No model!")
self.target_critic.load_state_dict(self.eval_critic.state_dict())
self.rnn_parameters = list(self.eval_rnn.parameters())
self.critic_parameters = list(self.eval_critic.parameters())
if args.optimizer == "RMS":
self.critic_optimizer = torch.optim.RMSprop(self.critic_parameters, lr=args.lr_critic)
self.rnn_optimizer = torch.optim.RMSprop(self.rnn_parameters, lr=args.lr_actor)
self.args = args
self.eval_hidden = None
def _get_critic_input_shape(self):
input_shape = self.state_shape
input_shape += self.obs_shape
input_shape += self.n_agents
input_shape += self.n_actions * self.n_agents * 2
return input_shape
def learn(self, batch, max_episode_len, train_step, epsilon):
episode_num = batch['o'].shape[0]
self.init_hidden(episode_num)
for key in batch.keys():
if key == 'u':
batch[key] = torch.tensor(batch[key], dtype=torch.long)
else:
batch[key] = torch.tensor(batch[key], dtype=torch.float32)
u, r, avail_u, terminated = batch['u'], batch['r'], batch['avail_u'], batch['terminated']
mask = (1 - batch["padded"].float()).repeat(1, 1, self.n_agents)
if self.args.cuda:
u = u.cuda()
mask = mask.cuda()
q_values = self._train_critic(batch, max_episode_len, train_step)
action_prob = self._get_action_prob(batch, max_episode_len, epsilon)
q_taken = torch.gather(q_values, dim=3, index=u).squeeze(3)
pi_taken = torch.gather(action_prob, dim=3, index=u).squeeze(3)
pi_taken[mask == 0] = 1.0
log_pi_taken = torch.log(pi_taken)
baseline = (q_values * action_prob).sum(dim=3, keepdim=True).squeeze(3).detach()
advantage = (q_taken - baseline).detach()
loss = - ((advantage * log_pi_taken) * mask).sum() / mask.sum()
self.rnn_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.rnn_parameters, self.args.grad_norm_clip)
self.rnn_optimizer.step()
def _get_critic_inputs(self, batch, transition_idx, max_episode_len):
obs, obs_next, s, s_next = batch['o'][:, transition_idx], batch['o_next'][:, transition_idx],\
batch['s'][:, transition_idx], batch['s_next'][:, transition_idx]
u_onehot = batch['u_onehot'][:, transition_idx]
if transition_idx != max_episode_len - 1:
u_onehot_next = batch['u_onehot'][:, transition_idx + 1]
else:
u_onehot_next = torch.zeros(*u_onehot.shape)
s = s.unsqueeze(1).expand(-1, self.n_agents, -1)
s_next = s_next.unsqueeze(1).expand(-1, self.n_agents, -1)
episode_num = obs.shape[0]
u_onehot = u_onehot.view((episode_num, 1, -1)).repeat(1, self.n_agents, 1)
u_onehot_next = u_onehot_next.view((episode_num, 1, -1)).repeat(1, self.n_agents, 1)
if transition_idx == 0:
u_onehot_last = torch.zeros_like(u_onehot)
else:
u_onehot_last = batch['u_onehot'][:, transition_idx - 1]
u_onehot_last = u_onehot_last.view((episode_num, 1, -1)).repeat(1, self.n_agents, 1)
inputs, inputs_next = [], []
inputs.append(s)
inputs_next.append(s_next)
inputs.append(obs)
inputs_next.append(obs_next)
inputs.append(u_onehot_last)
inputs_next.append(u_onehot)
'''
因为coma对于当前动作,输入的是其他agent的当前动作,不输入当前agent的动作,为了方便起见,每次虽然输入当前agent的
当前动作,但是将其置为0相量,也就相当于没有输入。
'''
action_mask = (1 - torch.eye(self.n_agents))
action_mask = action_mask.view(-1, 1).repeat(1, self.n_actions).view(self.n_agents, -1)
inputs.append(u_onehot * action_mask.unsqueeze(0))
inputs_next.append(u_onehot_next * action_mask.unsqueeze(0))
'''
因为当前的inputs三维的数据,每一维分别代表(episode编号,agent编号,inputs维度),直接在后面添加对应的向量
即可,比如给agent_0后面加(1, 0, 0, 0, 0),表示5个agent中的0号。而agent_0的数据正好在第0行,那么需要加的
agent编号恰好就是一个单位矩阵,即对角线为1,其余为0
'''
inputs.append(torch.eye(self.n_agents).unsqueeze(0).expand(episode_num, -1, -1))
inputs_next.append(torch.eye(self.n_agents).unsqueeze(0).expand(episode_num, -1, -1))
inputs = torch.cat([x.reshape(episode_num * self.n_agents, -1) for x in inputs], dim=1)
inputs_next = torch.cat([x.reshape(episode_num * self.n_agents, -1) for x in inputs_next], dim=1)
return inputs, inputs_next
def _get_q_values(self, batch, max_episode_len):
episode_num = batch['o'].shape[0]
q_evals, q_targets = [], []
for transition_idx in range(max_episode_len):
inputs, inputs_next = self._get_critic_inputs(batch, transition_idx, max_episode_len)
if self.args.cuda:
inputs = inputs.cuda()
inputs_next = inputs_next.cuda()
q_eval = self.eval_critic(inputs)
q_target = self.target_critic(inputs_next)
q_eval = q_eval.view(episode_num, self.n_agents, -1)
q_target = q_target.view(episode_num, self.n_agents, -1)
q_evals.append(q_eval)
q_targets.append(q_target)
q_evals = torch.stack(q_evals, dim=1)
q_targets = torch.stack(q_targets, dim=1)
return q_evals, q_targets
def _get_actor_inputs(self, batch, transition_idx):
obs, u_onehot = batch['o'][:, transition_idx], batch['u_onehot'][:]
episode_num = obs.shape[0]
inputs = []
inputs.append(obs)
if self.args.last_action:
if transition_idx == 0:
inputs.append(torch.zeros_like(u_onehot[:, transition_idx]))
else:
inputs.append(u_onehot[:, transition_idx - 1])
if self.args.reuse_network:
inputs.append(torch.eye(self.args.n_agents).unsqueeze(0).expand(episode_num, -1, -1))
inputs = torch.cat([x.reshape(episode_num * self.args.n_agents, -1) for x in inputs], dim=1)
return inputs
def _get_action_prob(self, batch, max_episode_len, epsilon):
episode_num = batch['o'].shape[0]
avail_actions = batch['avail_u']
action_prob = []
for transition_idx in range(max_episode_len):
inputs = self._get_actor_inputs(batch, transition_idx)
if self.args.cuda:
inputs = inputs.cuda()
self.eval_hidden = self.eval_hidden.cuda()
outputs, self.eval_hidden = self.eval_rnn(inputs, self.eval_hidden)
outputs = outputs.view(episode_num, self.n_agents, -1)
prob = torch.nn.functional.softmax(outputs, dim=-1)
action_prob.append(prob)
action_prob = torch.stack(action_prob, dim=1).cpu()
action_num = avail_actions.sum(dim=-1, keepdim=True).float().repeat(1, 1, 1, avail_actions.shape[-1])
action_prob = ((1 - epsilon) * action_prob + torch.ones_like(action_prob) * epsilon / action_num)
action_prob[avail_actions == 0] = 0.0
action_prob = action_prob / action_prob.sum(dim=-1, keepdim=True)
action_prob[avail_actions == 0] = 0.0
if self.args.cuda:
action_prob = action_prob.cuda()
return action_prob
def init_hidden(self, episode_num):
self.eval_hidden = torch.zeros((episode_num, self.n_agents, self.args.rnn_hidden_dim))
def _train_critic(self, batch, max_episode_len, train_step):
u, r, avail_u, terminated = batch['u'], batch['r'], batch['avail_u'], batch['terminated']
u_next = u[:, 1:]
padded_u_next = torch.zeros(*u[:, -1].shape, dtype=torch.long).unsqueeze(1)
u_next = torch.cat((u_next, padded_u_next), dim=1)
mask = (1 - batch["padded"].float()).repeat(1, 1, self.n_agents)
if self.args.cuda:
u = u.cuda()
u_next = u_next.cuda()
mask = mask.cuda()
q_evals, q_next_target = self._get_q_values(batch, max_episode_len)
q_values = q_evals.clone()
q_evals = torch.gather(q_evals, dim=3, index=u).squeeze(3)
q_next_target = torch.gather(q_next_target, dim=3, index=u_next).squeeze(3)
targets = td_lambda_target(batch, max_episode_len, q_next_target.cpu(), self.args)
if self.args.cuda:
targets = targets.cuda()
td_error = targets.detach() - q_evals
masked_td_error = mask * td_error
loss = (masked_td_error ** 2).sum() / mask.sum()
self.critic_optimizer.zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.critic_parameters, self.args.grad_norm_clip)
self.critic_optimizer.step()
if train_step > 0 and train_step % self.args.target_update_cycle == 0:
self.target_critic.load_state_dict(self.eval_critic.state_dict())
return q_values
def save_model(self, train_step):
num = str(train_step // self.args.save_cycle)
if not os.path.exists(self.model_dir):
os.makedirs(self.model_dir)
torch.save(self.eval_critic.state_dict(), self.model_dir + '/' + num + '_critic_params.pkl')
torch.save(self.eval_rnn.state_dict(), self.model_dir + '/' + num + '_rnn_params.pkl')
文章若有不当和不正确之处,还望理解与指出。由于部分文字、图片等来源于互联网,无法核实真实出处,如涉及相关争议,请联系博主删除。如有错误、疑问和侵权,欢迎评论留言联系作者,或者关注VX公众号:Rain21321,联系作者。
评论(0)