Controlling an Inverted Pendulum with the SAC Algorithm

Posted by HWCloudAI on 2022/11/28 20:07:24
Abstract: Case introduction. CartPoleContinuous is the continuous-action-space version of CartPole: a pole is attached by an un-actuated joint to a cart that moves along a frictionless track, and the system is controlled by applying a continuous force between -1 and +1 to the cart. The pendulum starts upright, and the goal is to keep it from falling over. In this case we show how to control CartPole well with the SAC algorithm.

Controlling an Inverted Pendulum with the SAC Algorithm - Assignment

You are welcome to share your completed assignment in the AI Gallery Notebook section to earn growth points; see this document for how to share.

Problem description

Adjust the training parameters in Step 2 and retrain the model so that it performs better in the game.

Hints:

  1. Search for "# Implement your code here" below; the place where that comment appears is where you need to modify the code
  2. After modifying the code, run the entire case end to end to complete the assignment, then share the completed work to AI Gallery with a title starting with "2021实战营"

Code implementation

1. Program initialization

Step 1: Install basic dependencies

!pip install gym pybullet

Step 2: Import the required libraries

import time
import random
import itertools

import gym
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import Adam
from torch.distributions import Normal
import pybullet_envs

2. Initialize training parameters

With num_steps = 30000 as set in this case, the agent can reach a score of 200, and training takes about 5 minutes.

# Implement your code here
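
The rest of the notebook refers to a number of hyperparameters (env_name, seed, lr, alpha, gamma, tau, and so on) that are meant to be defined in the cell above. The following is only a minimal sketch with commonly used SAC defaults; apart from num_steps = 30000, which the text above mentions, every value (including the environment id CartPoleContinuousBulletEnv-v0) is an assumption that you are expected to tune for the assignment.

# A possible parameter cell (assumed values; tune them for the assignment)
env_name = "CartPoleContinuousBulletEnv-v0"  # assumed pybullet_envs id for continuous CartPole
seed = 123456                 # random seed
num_steps = 30000             # total environment steps (value mentioned in the text)
batch_size = 256              # mini-batch size for SAC updates
lr = 3e-4                     # learning rate for actor, critic and temperature
alpha = 0.2                   # initial entropy temperature
auto_entropy = True           # automatically tune alpha
gamma = 0.99                  # discount factor
tau = 0.005                   # Polyak averaging coefficient for the target critic
target_update_interval = 1    # soft-update the target critic every N updates
updates_per_step = 1          # gradient updates per environment step
start_steps = 1000            # purely random exploration steps before using the policy
replay_size = 1000000         # replay buffer capacity

# Constants used by GaussianPolicy below (assumed values)
LOG_SIG_MAX = 2
LOG_SIG_MIN = -20
epsilon = 1e-6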

3. Define the SAC algorithm

Step 1: Define the Q networks. Q1 and Q2 share the same structure: three fully connected layers with sizes [256, 256, 1].

# Initialize network weights (Xavier uniform for linear layers)
def weights_init_(m):
    if isinstance(m, nn.Linear):
        torch.nn.init.xavier_uniform_(m.weight, gain=1)
        torch.nn.init.constant_(m.bias, 0)

class QNetwork(nn.Module):
    def __init__(self, num_inputs, num_actions):
        super(QNetwork, self).__init__()

        # Q1 architecture
        self.linear1 = nn.Linear(num_inputs + num_actions, 256)
        self.linear2 = nn.Linear(256, 256)
        self.linear3 = nn.Linear(256, 1)

        # Q2 architecture
        self.linear4 = nn.Linear(num_inputs + num_actions, 256)
        self.linear5 = nn.Linear(256, 256)
        self.linear6 = nn.Linear(256, 1)

        self.apply(weights_init_)

    def forward(self, state, action):
        xu = torch.cat([state, action], 1)

        x1 = F.relu(self.linear1(xu))
        x1 = F.relu(self.linear2(x1))
        x1 = self.linear3(x1)

        x2 = F.relu(self.linear4(xu))
        x2 = F.relu(self.linear5(x2))
        x2 = self.linear6(x2)

        return x1, x2
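
As a quick sanity check, the twin critic can be exercised with dummy tensors. The dimensions below (4-dimensional state, 1-dimensional action) are hypothetical and only illustrate the expected output shapes.

# Shape check for the twin Q network (hypothetical dimensions)
q_net = QNetwork(num_inputs=4, num_actions=1)
dummy_state = torch.zeros(8, 4)     # batch of 8 states
dummy_action = torch.zeros(8, 1)    # batch of 8 actions
q1, q2 = q_net(dummy_state, dummy_action)
print(q1.shape, q2.shape)           # torch.Size([8, 1]) torch.Size([8, 1])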

Step 2: Define the policy network. It uses a Gaussian distribution: two fully connected layers of size [256, 256], followed by a mean head and a log-standard-deviation head.

class GaussianPolicy(nn.Module):
    def __init__(self, num_inputs, num_actions, action_space=None):
        super(GaussianPolicy, self).__init__()

        self.linear1 = nn.Linear(num_inputs, 256)
        self.linear2 = nn.Linear(256, 256)

        self.mean_linear = nn.Linear(256, num_actions)
        self.log_std_linear = nn.Linear(256, num_actions)

        self.apply(weights_init_)

        # action rescaling
        if action_space is None:
            self.action_scale = torch.tensor(1.)
            self.action_bias = torch.tensor(0.)
        else:
            self.action_scale = torch.FloatTensor(
                (action_space.high - action_space.low) / 2.)
            self.action_bias = torch.FloatTensor(
                (action_space.high + action_space.low) / 2.)

    def forward(self, state):
        x = F.relu(self.linear1(state))
        x = F.relu(self.linear2(x))
        mean = self.mean_linear(x)
        log_std = self.log_std_linear(x)
        log_std = torch.clamp(log_std, min=LOG_SIG_MIN, max=LOG_SIG_MAX)
        return mean, log_std

    def sample(self, state):
        mean, log_std = self.forward(state)
        std = log_std.exp()
        normal = Normal(mean, std)
        # Reparameterization trick: x_t = mean + std * N(0, 1)
        x_t = normal.rsample()
        y_t = torch.tanh(x_t)
        action = y_t * self.action_scale + self.action_bias
        log_prob = normal.log_prob(x_t)
        # Correct the log-probability for the tanh squashing (change of variables)
        log_prob -= torch.log(self.action_scale * (1 - y_t.pow(2)) + epsilon)
        log_prob = log_prob.sum(1, keepdim=True)
        mean = torch.tanh(mean) * self.action_scale + self.action_bias
        return action, log_prob, mean

    def to(self, device):
        self.action_scale = self.action_scale.to(device)
        self.action_bias = self.action_bias.to(device)
        return super(GaussianPolicy, self).to(device)
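
A minimal usage sketch of the policy, assuming hypothetical state/action dimensions and that LOG_SIG_MIN, LOG_SIG_MAX and epsilon have been defined in the parameter cell:

# Sample actions from an untrained policy (hypothetical dimensions)
policy = GaussianPolicy(num_inputs=4, num_actions=1)
dummy_state = torch.zeros(8, 4)                  # batch of 8 states
action, log_prob, mean = policy.sample(dummy_state)
print(action.shape, log_prob.shape, mean.shape)  # all torch.Size([8, 1])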

Step 3: Define the SAC training logic

class SAC(object):
    def __init__(self, num_inputs, action_space):
        self.alpha = alpha
        self.auto_entropy = auto_entropy
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        # Critic network
        self.critic = QNetwork(num_inputs, action_space.shape[0]).to(device=self.device)
        self.critic_optim = Adam(self.critic.parameters(), lr=lr)
        # Target critic network
        self.critic_target = QNetwork(num_inputs, action_space.shape[0]).to(self.device)
        hard_update(self.critic_target, self.critic)

        # Target Entropy = −dim(A)
        if auto_entropy is True:
            self.target_entropy = -torch.prod(torch.Tensor(action_space.shape).to(self.device)).item()
            self.log_alpha = torch.zeros(1, requires_grad=True, device=self.device)
            self.alpha_optim = Adam([self.log_alpha], lr=lr)

        self.policy = GaussianPolicy(num_inputs, action_space.shape[0], action_space).to(self.device)
        self.policy_optim = Adam(self.policy.parameters(), lr=lr)

    def select_action(self, state):
        state = torch.FloatTensor(state).to(self.device).unsqueeze(0)
        action, _, _ = self.policy.sample(state)
        return action.detach().cpu().numpy()[0]

    def update_parameters(self, memory, batch_size, updates):
        # Sample a batch from memory
        state_batch, action_batch, reward_batch, next_state_batch, mask_batch = memory.sample(batch_size=batch_size)

        state_batch = torch.FloatTensor(state_batch).to(self.device)
        next_state_batch = torch.FloatTensor(next_state_batch).to(self.device)
        action_batch = torch.FloatTensor(action_batch).to(self.device)
        reward_batch = torch.FloatTensor(reward_batch).to(self.device).unsqueeze(1)
        mask_batch = torch.FloatTensor(mask_batch).to(self.device).unsqueeze(1)

        with torch.no_grad():
            # Sample the next action from the policy network
            next_state_action, next_state_log_pi, _ = self.policy.sample(next_state_batch)
            # Feed next_state and next_action into the target critic network to get their Q values
            qf1_next_target, qf2_next_target = self.critic_target(next_state_batch, next_state_action)
            min_qf_next_target = torch.min(qf1_next_target, qf2_next_target) - self.alpha * next_state_log_pi
            next_q_value = reward_batch + mask_batch * gamma * (min_qf_next_target)
        # Feed the current state and action into the critic network to get the current Q estimates
        qf1, qf2 = self.critic(state_batch, action_batch)
        # JQ = 𝔼(st,at)~D[0.5(Q1(st,at) - (r(st,at) + γ𝔼st+1~p[V(st+1)]))^2]
        qf1_loss = F.mse_loss(qf1, next_q_value)
        # The same objective for the second critic Q2
        qf2_loss = F.mse_loss(qf2, next_q_value)
        qf_loss = qf1_loss + qf2_loss

        self.critic_optim.zero_grad()
        qf_loss.backward()
        self.critic_optim.step()

        pi, log_pi, _ = self.policy.sample(state_batch)

        qf1_pi, qf2_pi = self.critic(state_batch, pi)
        min_qf_pi = torch.min(qf1_pi, qf2_pi)

        # Jπ = 𝔼st∼D,εt∼N[α * logπ(f(εt;st)|st) − Q(st,f(εt;st))]
        policy_loss = ((self.alpha * log_pi) - min_qf_pi).mean()

        self.policy_optim.zero_grad()
        policy_loss.backward()
        self.policy_optim.step()

        if self.auto_entropy:
            alpha_loss = -(self.log_alpha * (log_pi + self.target_entropy).detach()).mean()

            self.alpha_optim.zero_grad()
            alpha_loss.backward()
            self.alpha_optim.step()

            self.alpha = self.log_alpha.exp()
        else:
            alpha_loss = torch.tensor(0.).to(self.device)

        # Periodically Polyak-average the target critic towards the current critic
        if updates % target_update_interval == 0:
            soft_update(self.critic_target, self.critic, tau)


def soft_update(target, source, tau):
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(target_param.data * (1.0 - tau) + param.data * tau)


def hard_update(target, source):
    for target_param, param in zip(target.parameters(), source.parameters()):
        target_param.data.copy_(param.data)

Step 4: Define the replay buffer, which stores transitions [s, a, r, s_, done]

class ReplayMemory:
    def __init__(self, capacity):
        random.seed(seed)
        self.capacity = capacity
        self.buffer = []
        self.position = 0

    def push(self, state, action, reward, next_state, done):
        if len(self.buffer) < self.capacity:
            self.buffer.append(None)
        self.buffer[self.position] = (state, action, reward, next_state, done)
        self.position = (self.position + 1) % self.capacity

    def sample(self, batch_size):
        batch = random.sample(self.buffer, batch_size)
        state, action, reward, next_state, done = map(np.stack, zip(*batch))
        return state, action, reward, next_state, done

    def __len__(self):
        return len(self.buffer)
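
A quick sketch of how the buffer is used; the transition values below are hypothetical placeholders, and the sketch assumes seed has been defined in the parameter cell:

# Push a few dummy transitions and sample a mini-batch (hypothetical values)
demo_memory = ReplayMemory(capacity=100)
for _ in range(10):
    demo_memory.push(state=np.zeros(4), action=np.zeros(1), reward=0.0,
                     next_state=np.zeros(4), done=1.0)
states, actions, rewards, next_states, masks = demo_memory.sample(batch_size=4)
print(states.shape, actions.shape)  # (4, 4) (4, 1)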

4. Train the model

Initialize the environment and the algorithm

# Create the environment
env = gym.make(env_name)
# Set random seeds for reproducibility
env.seed(seed)
env.action_space.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

# Create the agent
agent = SAC(env.observation_space.shape[0], env.action_space)

# Replay buffer
memory = ReplayMemory(replay_size)

# Bookkeeping for training steps
total_numsteps = 0
updates = 0
max_reward = 0

Start training

print('\ntraining...')
begin_t = time.time()
for i_episode in itertools.count(1):
    episode_reward = 0
    episode_steps = 0
    done = False
    state = env.reset()

    while not done:
        if start_steps > total_numsteps:
            # Take random actions during the initial exploration phase
            action = env.action_space.sample()
        else:
            # Sample an action from the current policy
            action = agent.select_action(state)

        if len(memory) > batch_size:
            # Number of gradient updates per environment step
            for i in range(updates_per_step):
                agent.update_parameters(memory, batch_size, updates)
                updates += 1

        # Take the step in the environment
        next_state, reward, done, _ = env.step(action)
        # Update bookkeeping counters
        episode_steps += 1
        total_numsteps += 1
        episode_reward += reward

        # mask is 0 only for a true terminal state; time-limit truncation keeps mask = 1
        mask = 1 if episode_steps == env._max_episode_steps else float(not done)

        # Store the transition in the replay buffer
        memory.push(state, action, reward, next_state, mask)

        # Move on to the next state
        state = next_state

    # Stop once the total step budget has been used up
    if total_numsteps > num_steps:
        break

    if episode_reward >= max_reward:
        max_reward = episode_reward
        print("current_max_reward {}".format(max_reward))
        # Save the best model so far
        torch.save(agent.policy, "model.pt")

    print("Episode: {}, total numsteps: {}, reward: {}".format(i_episode, total_numsteps,round(episode_reward, 2)))

env.close()
print("finish! time cost is {}s".format(time.time() - begin_t))

5. Run inference with the trained model

Visualization in this kernel depends on OpenGL and requires a display window, which the current environment does not support, so the rendering cannot be shown here. To see it, download the code to your local machine and uncomment the env.render() lines.

# Inference / visualization
model = torch.load("model.pt")
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state = env.reset()
# env.render()
done = False
episode_reward = 0
while not done:
    _, _, action = model.sample(torch.FloatTensor(state).to(device).unsqueeze(0))
    action = action.detach().cpu().numpy()[0]
    next_state, reward, done, _ = env.step(action)
    episode_reward += reward
    # env.render()
    state = next_state
print(episode_reward)

The visualization result: [screenshot omitted]
