Controlling the Inverted Pendulum with the DPPO Algorithm

Posted by HWCloudAI on 2022/11/29 10:45:40
[Abstract] The pendulum swing-up problem is a classic problem in the control literature. In the Pendulum-v0 problem solved with DPPO in this section, the pendulum starts at a random position and swings around a pivot point; the goal is to swing the pendulum up and keep it upright.

Controlling the Inverted Pendulum with the DPPO Algorithm

Experiment Goals

Through studying this case and completing the exercise, you will:

  1. Understand the basic concepts of DPPO
  2. Learn how to train an agent for a control problem with DPPO
  3. Understand the overall workflow of training and inference for control problems with reinforcement learning

You can also share the ipynb study notes for this case to the AI Gallery Notebook section to earn growth points; see this document for how to share.

Case Introduction

The pendulum swing-up problem is a classic problem in the control literature. In the Pendulum-v0 problem solved with DPPO in this section, the pendulum starts at a random position and swings around a pivot point; the goal is to swing the pendulum up and keep it upright. A demo of the pendulum driven by random actions is shown below:
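
Since the demo animation cannot be embedded here, the following minimal sketch (assuming gym with the classic-control environments is installed locally) runs the same kind of random-action rollout:

# Minimal random-action rollout in Pendulum-v0 (assumes gym is installed; rendering needs a local display)
import gym

env = gym.make('Pendulum-v0')
s = env.reset()
for _ in range(200):                  # Pendulum-v0 episodes are capped at 200 steps
    a = env.action_space.sample()     # random torque in [-2, 2]
    s, r, done, _ = env.step(a)       # reward is highest (near 0) at the upright position
    # env.render()                    # uncomment to watch the swing locally
    if done:
        break
env.close()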




Overall workflow: install basic dependencies -> create the Pendulum environment -> build the DPPO algorithm -> train -> infer -> visualize the result

Basic structure of the Distributed Proximal Policy Optimization (DPPO) algorithm

DPPO builds on Proximal Policy Optimization (PPO); for PPO itself, see "Playing 'Super Mario Bros.' with the PPO Algorithm", where we give a detailed introduction. DPPO borrows the parallel scheme of A3C: multiple workers collect data in different copies of the environment in parallel and compute gradients from the collected data, then send the gradients to a global chief. Once the chief has received gradients from a sufficient number of workers, it updates the network; during the update the workers stop collecting and wait until the update finishes, after which they resume collecting data with the latest network.
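
Conceptually, the chief/worker synchronization can be stripped down to the sketch below. The names (data_queue, update_event, rolling_event, worker, chief, env_step, policy, update_policy) are illustrative placeholders only; the implementation later in this notebook uses QUEUE, UPDATE_EVENT, ROLLING_EVENT and a tf.train.Coordinator, and it ships collected transitions (rather than gradients) to the chief, which computes the PPO update itself. The sketch follows that convention:

# Stripped-down sketch of the worker/chief handshake used later in this notebook (illustrative names)
import queue
import threading

data_queue = queue.Queue()
update_event, rolling_event = threading.Event(), threading.Event()
rolling_event.set()                      # workers may collect data
MIN_BATCH = 64
counter, counter_lock = 0, threading.Lock()

def worker(env_step, policy):
    global counter
    while True:
        rolling_event.wait()             # pause here while the chief is updating
        data_queue.put(env_step(policy)) # collect one transition with the current policy
        with counter_lock:
            counter += 1
            if counter >= MIN_BATCH:     # enough samples gathered across all workers
                rolling_event.clear()    # stop collection ...
                update_event.set()       # ... and wake the chief

def chief(update_policy):
    global counter
    while True:
        update_event.wait()              # sleep until a full batch is ready
        batch = [data_queue.get() for _ in range(data_queue.qsize())]
        update_policy(batch)             # PPO update on the pooled data
        counter = 0
        update_event.clear()
        rolling_event.set()              # let workers resume with the new policy

The real code below additionally uses a tf.train.Coordinator so that all threads can be asked to stop once GLOBAL_EP reaches EP_MAX.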

Below we use the pseudocode from the paper to describe the details of DPPO.

The algorithm above is the pseudocode of the global PPO (chief), where W is the number of workers, D is the threshold on the number of workers whose gradients must be available before a parameter update, M and B are the numbers of sub-iterations of policy-network and critic-network updates for a given batch of data points, and θ and φ are the parameters of the policy network and the critic network.



The algorithm above is the pseudocode of the workers, where T is the number of data points each worker collects before a parameter update, and K is the number of time steps over which K-step returns and truncated backpropagation through time (for RNNs) are computed. This part is based on PPO: each worker first collects data, computes the gradient according to the PPO algorithm, sends it to the global chief, and waits until the chief has finished updating the parameters before collecting data again.
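
The K-step return mentioned above corresponds to the bootstrapped discounted return that each worker computes later in this notebook: starting from the critic's estimate for the last state, the reward buffer is walked backwards with G_t = r_t + gamma * G_{t+1}. A tiny standalone numeric example (made-up rewards and value estimate):

# Bootstrapped discounted return, computed backwards from V(s_T) (made-up numbers)
import numpy as np

gamma = 0.9
rewards = [1.0, 0.0, -1.0]   # r_0, r_1, r_2 collected by one worker
v_last = 2.0                 # critic's value estimate for the state reached after r_2

returns = []
g = v_last
for r in reversed(rewards):
    g = r + gamma * g
    returns.append(g)
returns.reverse()
print(np.round(returns, 3))  # [1.648, 0.72, 0.8]; these become critic targets, and advantages = returns - V(s)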

DPPO paper

The code is adapted from an open-source GitHub project.

Notes

  1. This case runs on TensorFlow-2.0.0 and requires a GPU. See the "ModelArts JupyterLab Hardware Specification Guide" to learn how to switch hardware specifications;

  2. If this is your first time using JupyterLab, see the "ModelArts JupyterLab User Guide" to learn how to use it;

  3. If you run into errors while using JupyterLab, see the "ModelArts JupyterLab FAQ" for possible solutions.

Experiment Steps

1. Program initialization

Step 1: Install basic dependencies

!pip install tensorflow==2.0.0
!pip install tensorflow-probability==0.7.0
!pip install tensorlayer==2.1.0 --ignore-installed
!pip install h5py==2.10.0
!pip install gym

Step 2: Import the required libraries

import os
import time
import queue
import threading

import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import tensorlayer as tl

2. Training parameter initialization

Training with EP_MAX = 1000 episodes achieves good results and takes about 20 minutes; the code below sets EP_MAX = 100 so that the notebook runs through quickly.

You can adjust the value of EP_MAX yourself to trade off training time against final performance.

RANDOMSEED = 1  # random seed

EP_MAX = 100  # total number of training episodes
EP_LEN = 200  # maximum length of one episode
GAMMA = 0.9  # discount factor
A_LR = 0.0001  # actor learning rate
C_LR = 0.0002  # critic learning rate
BATCH = 32  # batch size
A_UPDATE_STEPS = 10  # number of actor update steps
C_UPDATE_STEPS = 10  # number of critic update steps
S_DIM, A_DIM = 3, 1  # state dimension, action dimension
EPS = 1e-8  # epsilon value

# Parameters for PPO1 and PPO2; you can choose PPO1 (METHOD[0]) or PPO2 (METHOD[1]).
# A small numeric sketch of the two objectives follows this code block.
METHOD = [
    dict(name='kl_pen', kl_target=0.01, lam=0.5),  # KL penalty
    dict(name='clip', epsilon=0.2),  # Clipped surrogate objective, find this is better
][1]  # choose the method for optimization

N_WORKER = 4  # number of parallel workers
MIN_BATCH_SIZE = 64  # minibatch size for a PPO update
UPDATE_STEP = 10  # update every 10 steps
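
To make the METHOD choice above concrete, here is a small standalone numeric sketch of the two surrogate losses (NumPy only, not part of the training code; the ratio, advantage and KL values are made up for illustration):

# PPO1 (KL penalty) vs. PPO2 (clipped surrogate), on three made-up samples
import numpy as np

ratio = np.array([0.8, 1.0, 1.4])   # pi(a|s) / pi_old(a|s)
adv = np.array([1.0, -0.5, 2.0])    # advantage estimates
eps, lam, kl = 0.2, 0.5, 0.05       # clip range, KL weight, mean KL(oldpi || pi)

# PPO2 (METHOD[1]): clipped surrogate objective, loss is its negative mean
clipped = np.minimum(ratio * adv, np.clip(ratio, 1 - eps, 1 + eps) * adv)
loss_clip = -clipped.mean()

# PPO1 (METHOD[0]): KL-penalty objective, lam is adapted towards kl_target during training
loss_kl = -(ratio * adv - lam * kl).mean()

print(loss_clip, loss_kl)           # -0.9, about -1.008

With the default METHOD[1], only the clipped loss is used; switching the index to [0] selects the KL-penalty variant.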

3. Create the environment

The environment is gym's built-in Pendulum; the per-step reward penalizes deviation from the upright position, so the agent must swing the pendulum up and keep it balanced.

env_name = 'Pendulum-v0'  # environment name
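
As an optional sanity check (assuming the same gym version is installed), the environment's spaces can be printed to confirm the S_DIM = 3 and A_DIM = 1 settings from step 2:

# Quick check of the observation/action spaces of Pendulum-v0
env = gym.make(env_name)
print(env.observation_space)   # Box(3,)  -> cos(theta), sin(theta), theta_dot
print(env.action_space)        # Box(1,)  -> torque in [-2, 2]
env.close()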

4. Define the DPPO algorithm

DPPO algorithm: the PPO component

class PPO(object):
    '''
    PPO class
    '''

    def __init__(self):

        # build the critic network
        tfs = tl.layers.Input([None, S_DIM], tf.float32, 'state')
        l1 = tl.layers.Dense(100, tf.nn.relu)(tfs)
        v = tl.layers.Dense(1)(l1)
        self.critic = tl.models.Model(tfs, v)
        self.critic.train()

        # build the actor: current policy and a frozen copy of the old policy
        self.actor = self._build_anet('pi', trainable=True)
        self.actor_old = self._build_anet('oldpi', trainable=False)
        self.actor_opt = tf.optimizers.Adam(A_LR)
        self.critic_opt = tf.optimizers.Adam(C_LR)

    # update the actor
    def a_train(self, tfs, tfa, tfadv):
        '''
        Update policy network
        :param tfs: state
        :param tfa: act
        :param tfadv: advantage
        :return:
        '''
        tfs = np.array(tfs, np.float32)
        tfa = np.array(tfa, np.float32)
        tfadv = np.array(tfadv, np.float32)  # td-error
        with tf.GradientTape() as tape:
            mu, sigma = self.actor(tfs)
            pi = tfp.distributions.Normal(mu, sigma)

            mu_old, sigma_old = self.actor_old(tfs)
            oldpi = tfp.distributions.Normal(mu_old, sigma_old)

            ratio = pi.prob(tfa) / (oldpi.prob(tfa) + EPS)
            surr = ratio * tfadv

            ## PPO1
            if METHOD['name'] == 'kl_pen':
                tflam = METHOD['lam']
                kl = tfp.distributions.kl_divergence(oldpi, pi)
                kl_mean = tf.reduce_mean(kl)
                aloss = -(tf.reduce_mean(surr - tflam * kl))
            ## PPO2
            else:
                aloss = -tf.reduce_mean(
                    tf.minimum(surr,
                               tf.clip_by_value(ratio, 1. - METHOD['epsilon'], 1. + METHOD['epsilon']) * tfadv)
                )
        a_grad = tape.gradient(aloss, self.actor.trainable_weights)

        self.actor_opt.apply_gradients(zip(a_grad, self.actor.trainable_weights))

        if METHOD['name'] == 'kl_pen':
            return kl_mean

    # update old_pi (copy the current policy parameters into the snapshot)
    def update_old_pi(self):
        '''
        Update old policy parameter
        :return: None
        '''
        for p, oldp in zip(self.actor.trainable_weights, self.actor_old.trainable_weights):
            oldp.assign(p)

    # update the critic
    def c_train(self, tfdc_r, s):
        '''
        Update critic network
        :param tfdc_r: cumulative reward
        :param s: state
        :return: None
        '''
        tfdc_r = np.array(tfdc_r, dtype=np.float32)
        with tf.GradientTape() as tape:
            advantage = tfdc_r - self.critic(s)  # advantage = discounted return - V(s)
            closs = tf.reduce_mean(tf.square(advantage))
        grad = tape.gradient(closs, self.critic.trainable_weights)
        self.critic_opt.apply_gradients(zip(grad, self.critic.trainable_weights))

    # compute advantage: discounted return - V(s)
    def cal_adv(self, tfs, tfdc_r):
        '''
        Calculate advantage
        :param tfs: state
        :param tfdc_r: cumulative reward
        :return: advantage
        '''
        tfdc_r = np.array(tfdc_r, dtype=np.float32)
        advantage = tfdc_r - self.critic(tfs)
        return advantage.numpy()

    def update(self):
        '''
        Update parameter with the constraint of KL divergent
        :return: None
        '''
        global GLOBAL_UPDATE_COUNTER
        while not COORD.should_stop():  # run until the coordinator signals stop
            if GLOBAL_EP < EP_MAX:  # EP_MAX is the maximum number of episodes
                UPDATE_EVENT.wait()  # the global PPO (chief) thread waits here
                self.update_old_pi()  # copy pi to old pi
                data = [QUEUE.get() for _ in range(QUEUE.qsize())]  # collect data from all workers
                data = np.vstack(data)

                s, a, r = data[:, :S_DIM].astype(np.float32), \
                          data[:, S_DIM: S_DIM + A_DIM].astype(np.float32), \
                          data[:, -1:].astype(np.float32)

                adv = self.cal_adv(s, r)

                # update actor
                ## PPO1
                if METHOD['name'] == 'kl_pen':
                    for _ in range(A_UPDATE_STEPS):
                        kl = self.a_train(s, a, adv)
                        if kl > 4 * METHOD['kl_target']:  # this is in Google's paper
                            break
                    if kl < METHOD['kl_target'] / 1.5:  # adaptive lambda, this is in OpenAI's paper
                        METHOD['lam'] /= 2
                    elif kl > METHOD['kl_target'] * 1.5:
                        METHOD['lam'] *= 2
                    # sometimes explode, this clipping is MorvanZhou's solution
                    METHOD['lam'] = np.clip(METHOD['lam'], 1e-4, 10)

                ## PPO2
                else:  # clipping method, find this is better (OpenAI's paper)
                    for _ in range(A_UPDATE_STEPS):
                        self.a_train(s, a, adv)

                # update the critic
                for _ in range(C_UPDATE_STEPS):
                    self.c_train(r, s)

                UPDATE_EVENT.clear()  # updating finished
                GLOBAL_UPDATE_COUNTER = 0  # reset counter
                ROLLING_EVENT.set()  # set roll-out available

    # build the actor network
    def _build_anet(self, name, trainable):
        '''
        Build policy network
        :param name: name
        :param trainable: trainable flag
        :return: policy network
        '''
        tfs = tl.layers.Input([None, S_DIM], tf.float32, name + '_state')
        l1 = tl.layers.Dense(100, tf.nn.relu, name=name + '_l1')(tfs)
        a = tl.layers.Dense(A_DIM, tf.nn.tanh, name=name + '_a')(l1)
        mu = tl.layers.Lambda(lambda x: x * 2, name=name + '_lambda')(a)
        sigma = tl.layers.Dense(A_DIM, tf.nn.softplus, name=name + '_sigma')(l1)
        model = tl.models.Model(tfs, [mu, sigma], name)

        if trainable:
            model.train()
        else:
            model.eval()
        return model

    # choose an action
    def choose_action(self, s):
        '''
        Choose action
        :param s: state
        :return: clipped act
        '''
        s = s[np.newaxis, :].astype(np.float32)
        mu, sigma = self.actor(s)
        pi = tfp.distributions.Normal(mu, sigma)
        a = tf.squeeze(pi.sample(1), axis=0)[0]  # choosing action
        return np.clip(a, -2, 2)

    # compute V(s)
    def get_v(self, s):
        '''
        Compute value
        :param s: state
        :return: value
        '''
        s = s.astype(np.float32)
        if s.ndim < 2: s = s[np.newaxis, :]
        return self.critic(s)[0, 0]

    def save_ckpt(self):
        """
        save trained weights
        :return: None
        """
        if not os.path.exists('model_Pendulum'):
            os.makedirs('model_Pendulum')
        tl.files.save_weights_to_hdf5('model_Pendulum/dppo_actor.hdf5', self.actor)
        tl.files.save_weights_to_hdf5('model_Pendulum/dppo_actor_old.hdf5', self.actor_old)
        tl.files.save_weights_to_hdf5('model_Pendulum/dppo_critic.hdf5', self.critic)

    def load_ckpt(self):
        """
        load trained weights
        :return: None
        """
        tl.files.load_hdf5_to_weights_in_order('model_Pendulum/dppo_actor.hdf5', self.actor)
        tl.files.load_hdf5_to_weights_in_order('model_Pendulum/dppo_actor_old.hdf5', self.actor_old)
        tl.files.load_hdf5_to_weights_in_order('model_Pendulum/dppo_critic.hdf5', self.critic)
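
The core quantity in a_train above is the importance ratio between the current policy and the old-policy snapshot. The standalone illustration below (arbitrary numbers; it only reuses the EPS constant from step 2) shows how that ratio is obtained from two Normal distributions:

# Importance ratio pi(a|s) / oldpi(a|s) for one sampled action (arbitrary illustrative numbers)
new_pi = tfp.distributions.Normal(loc=[0.2], scale=[0.4])   # current policy pi(.|s)
old_pi = tfp.distributions.Normal(loc=[0.0], scale=[0.5])   # snapshot oldpi(.|s)
a = tf.constant([0.1])                                      # action actually taken
ratio = new_pi.prob(a) / (old_pi.prob(a) + EPS)             # the ratio used in the surrogate loss
print(ratio.numpy())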

Building the workers

class Worker(object):
    '''
    Worker class for distributional running
    '''

    def __init__(self, wid):
        self.wid = wid  # worker id
        self.env = gym.make(env_name).unwrapped  # create the environment
        self.env.seed(wid * 100 + RANDOMSEED)  # give each worker a different seed so their rollouts are not identical
        self.ppo = GLOBAL_PPO  # the shared algorithm

    def work(self):
        '''
        Define a worker
        :return: None
        '''
        global GLOBAL_EP, GLOBAL_RUNNING_R, GLOBAL_UPDATE_COUNTER
        while not COORD.should_stop():  # check with COORD whether we should stop
            s = self.env.reset()
            ep_r = 0
            buffer_s, buffer_a, buffer_r = [], [], []  # buffers for the collected data
            t0 = time.time()
            for t in range(EP_LEN):
                # check whether the global PPO is being updated; if so, wait here
                if not ROLLING_EVENT.is_set():  # if the event is cleared, the global PPO is updating; otherwise keep rolling out
                    ROLLING_EVENT.wait()  # the worker threads wait here until PPO is updated
                    buffer_s, buffer_a, buffer_r = [], [], []  # clear history buffer, use new policy to collect data

                # step the environment normally and collect data
                a = self.ppo.choose_action(s)
                s_, r, done, _ = self.env.step(a)
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append((r + 8) / 8)  # normalize reward, find to be useful
                s = s_
                ep_r += r

                # GLOBAL_UPDATE_COUNTER increases by 1 every time any worker takes a step, i.e. produces one transition.
                # Once it reaches MIN_BATCH_SIZE (64), an update can start.
                GLOBAL_UPDATE_COUNTER += 1  # count to minimum batch size, no need to wait other workers
                if t == EP_LEN - 1 or GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:  # t == EP_LEN - 1 is the last step
                    ## compute the bootstrapped discounted return for each state, starting from V(s')
                    ## note that len(buffer) < GLOBAL_UPDATE_COUNTER here, so each worker processes its own data
                    v_s_ = self.ppo.get_v(s_)
                    discounted_r = []  # compute discounted reward
                    for r in buffer_r[::-1]:
                        v_s_ = r + GAMMA * v_s_
                        discounted_r.append(v_s_)
                    discounted_r.reverse()

                    ## stack the data and put it into the shared queue
                    bs, ba, br = np.vstack(buffer_s), np.vstack(buffer_a), np.array(discounted_r)[:, np.newaxis]
                    buffer_s, buffer_a, buffer_r = [], [], []
                    QUEUE.put(np.hstack((bs, ba, br)))  # put data in the queue

                    # if enough data has been collected, trigger an update
                    if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:
                        ROLLING_EVENT.clear()  # stop collecting data
                        UPDATE_EVENT.set()  # global PPO update

                    if GLOBAL_EP >= EP_MAX:  # stop training
                        COORD.request_stop()  # ask all threads to stop
                        break

            # record reward changes, plot later
            if len(GLOBAL_RUNNING_R) == 0:
                GLOBAL_RUNNING_R.append(ep_r)
            else:
                GLOBAL_RUNNING_R.append(GLOBAL_RUNNING_R[-1] * 0.9 + ep_r * 0.1)
            GLOBAL_EP += 1

            print(
                'Episode: {}/{}  | Worker: {} | Episode Reward: {:.4f}  | Running Time: {:.4f}'.format(
                    GLOBAL_EP, EP_MAX, self.wid, ep_r,
                    time.time() - t0
                )
            )

5. Model training

np.random.seed(RANDOMSEED)
tf.random.set_seed(RANDOMSEED)

GLOBAL_PPO = PPO()
[TL] Input  state: [None, 3]

[TL] Dense  dense_1: 100 relu

[TL] Dense  dense_2: 1 No Activation

[TL] Input  pi_state: [None, 3]

[TL] Dense  pi_l1: 100 relu

[TL] Dense  pi_a: 1 tanh

[TL] Lambda  pi_lambda: func: <function PPO._build_anet.<locals>.<lambda> at 0x7fb65d633950>, len_weights: 0

[TL] Dense  pi_sigma: 1 softplus

[TL] Input  oldpi_state: [None, 3]

[TL] Dense  oldpi_l1: 100 relu

[TL] Dense  oldpi_a: 1 tanh

[TL] Lambda  oldpi_lambda: func: <function PPO._build_anet.<locals>.<lambda> at 0x7fb65d633a70>, len_weights: 0

[TL] Dense  oldpi_sigma: 1 softplus
# define two events: update and rolling
UPDATE_EVENT, ROLLING_EVENT = threading.Event(), threading.Event()
UPDATE_EVENT.clear()  # not updating now; this sets the flag to False
ROLLING_EVENT.set()  # start to roll out; this sets the flag to True and wakes all threads waiting on it

# create the workers
workers = [Worker(wid=i) for i in range(N_WORKER)]

GLOBAL_UPDATE_COUNTER, GLOBAL_EP = 0, 0  # counter of steps collected since the last update, global episode counter
GLOBAL_RUNNING_R = []  # running reward record used to track training progress
COORD = tf.train.Coordinator()  # TensorFlow coordinator for the threads
QUEUE = queue.Queue()  # workers putting data in this queue
threads = []

# create a thread for each worker
for worker in workers:  # worker threads
    t = threading.Thread(target=worker.work, args=())  # create the thread
    t.start()  # start the thread
    threads.append(t)  # keep the thread in a list for easier management

# add a PPO updating thread
# append the global PPO update thread to the end of the thread list
threads.append(threading.Thread(target=GLOBAL_PPO.update, ))
threads[-1].start()
COORD.join(threads)  # hand the thread list over to the coordinator

GLOBAL_PPO.save_ckpt()  # save the global network weights

# plot reward change and test
plt.title('DPPO')
plt.plot(np.arange(len(GLOBAL_RUNNING_R)), GLOBAL_RUNNING_R)
plt.xlabel('Episode')
plt.ylabel('Moving reward')
plt.ylim(-2000, 0)
plt.show()
Episode: 1/100  | Worker: 1 | Episode Reward: -965.6343  | Running Time: 3.1675

Episode: 2/100  | Worker: 2 | Episode Reward: -1443.1138  | Running Time: 3.1689

Episode: 3/100  | Worker: 3 | Episode Reward: -1313.6248  | Running Time: 3.1734

Episode: 4/100  | Worker: 0 | Episode Reward: -1403.1952  | Running Time: 3.1819

Episode: 5/100  | Worker: 1 | Episode Reward: -1399.3963  | Running Time: 3.2429

Episode: 6/100  | Worker: 2 | Episode Reward: -1480.8439  | Running Time: 3.2453

Episode: 7/100  | Worker: 0 | Episode Reward: -1489.4195  | Running Time: 3.2373

Episode: 8/100  | Worker: 3 | Episode Reward: -1339.0517  | Running Time: 3.2583

Episode: 9/100  | Worker: 1 | Episode Reward: -1600.1292  | Running Time: 3.2478

Episode: 10/100  | Worker: 0 | Episode Reward: -1513.2170  | Running Time: 3.2584

Episode: 11/100  | Worker: 2 | Episode Reward: -1461.7279  | Running Time: 3.2697

Episode: 12/100  | Worker: 3 | Episode Reward: -1480.2685  | Running Time: 3.2598

Episode: 13/100  | Worker: 0 | Episode Reward: -1831.5374  | Running Time: 3.2423

Episode: 14/100  | Worker: 1 | Episode Reward: -1524.8253  | Running Time: 3.2635
Episode: 15/100  | Worker: 2 | Episode Reward: -1383.4878  | Running Time: 3.2556

Episode: 16/100  | Worker: 3 | Episode Reward: -1288.9392  | Running Time: 3.2588

Episode: 17/100  | Worker: 1 | Episode Reward: -1657.2223  | Running Time: 3.2377

Episode: 18/100  | Worker: 0 | Episode Reward: -1472.2335  | Running Time: 3.2678

Episode: 19/100  | Worker: 2 | Episode Reward: -1475.5421  | Running Time: 3.2667

Episode: 20/100  | Worker: 3 | Episode Reward: -1532.7678  | Running Time: 3.2739

Episode: 21/100  | Worker: 1 | Episode Reward: -1575.5706  | Running Time: 3.2688

Episode: 22/100  | Worker: 2 | Episode Reward: -1238.4006  | Running Time: 3.2303

Episode: 23/100  | Worker: 0 | Episode Reward: -1630.9554  | Running Time: 3.2584

Episode: 24/100  | Worker: 3 | Episode Reward: -1610.7237  | Running Time: 3.2601

Episode: 25/100  | Worker: 1 | Episode Reward: -1516.5440  | Running Time: 3.2683

Episode: 26/100  | Worker: 0 | Episode Reward: -1547.6209  | Running Time: 3.2589

Episode: 27/100  | Worker: 2 | Episode Reward: -1328.2584  | Running Time: 3.2762

Episode: 28/100  | Worker: 3 | Episode Reward: -1191.0914  | Running Time: 3.2552

Episode: 29/100  | Worker: 1 | Episode Reward: -1415.3608  | Running Time: 3.2804

Episode: 30/100  | Worker: 0 | Episode Reward: -1765.8007  | Running Time: 3.2767

Episode: 31/100  | Worker: 2 | Episode Reward: -1756.5872  | Running Time: 3.3078

Episode: 32/100  | Worker: 3 | Episode Reward: -1428.0094  | Running Time: 3.2815

Episode: 33/100  | Worker: 1 | Episode Reward: -1605.7720  | Running Time: 3.3010

Episode: 34/100  | Worker: 0 | Episode Reward: -1247.7492  | Running Time: 3.3115

Episode: 35/100  | Worker: 2 | Episode Reward: -1333.9553  | Running Time: 3.2759
Episode: 36/100  | Worker: 3 | Episode Reward: -1485.7453  | Running Time: 3.2749

Episode: 37/100  | Worker: 3 | Episode Reward: -1341.3090  | Running Time: 3.2323

Episode: 38/100  | Worker: 2 | Episode Reward: -1472.5245  | Running Time: 3.2595

Episode: 39/100  | Worker: 0 | Episode Reward: -1583.6614  | Running Time: 3.2721

Episode: 40/100  | Worker: 1 | Episode Reward: -1358.4421  | Running Time: 3.2925

Episode: 41/100  | Worker: 3 | Episode Reward: -1744.7500  | Running Time: 3.2391

Episode: 42/100  | Worker: 2 | Episode Reward: -1684.8821  | Running Time: 3.2527

Episode: 43/100  | Worker: 1 | Episode Reward: -1412.0231  | Running Time: 3.2400

Episode: 44/100  | Worker: 0 | Episode Reward: -1437.6130  | Running Time: 3.2458

Episode: 45/100  | Worker: 3 | Episode Reward: -1461.7901  | Running Time: 3.2872

Episode: 46/100  | Worker: 2 | Episode Reward: -1572.6255  | Running Time: 3.2710

Episode: 47/100  | Worker: 0 | Episode Reward: -1704.6351  | Running Time: 3.2762

Episode: 48/100  | Worker: 1 | Episode Reward: -1538.4030  | Running Time: 3.3117

Episode: 49/100  | Worker: 3 | Episode Reward: -1554.7941  | Running Time: 3.2881

Episode: 50/100  | Worker: 2 | Episode Reward: -1796.0786  | Running Time: 3.2718

Episode: 51/100  | Worker: 0 | Episode Reward: -1877.3152  | Running Time: 3.2804

Episode: 52/100  | Worker: 1 | Episode Reward: -1749.8780  | Running Time: 3.2779

Episode: 53/100  | Worker: 3 | Episode Reward: -1486.8338  | Running Time: 3.1559

Episode: 54/100  | Worker: 2 | Episode Reward: -1540.8134  | Running Time: 3.2903

Episode: 55/100  | Worker: 0 | Episode Reward: -1596.7365  | Running Time: 3.3156

Episode: 56/100  | Worker: 1 | Episode Reward: -1644.7888  | Running Time: 3.3065

Episode: 57/100  | Worker: 3 | Episode Reward: -1514.0685  | Running Time: 3.2920

Episode: 58/100  | Worker: 2 | Episode Reward: -1411.2714  | Running Time: 3.1554

Episode: 59/100  | Worker: 0 | Episode Reward: -1602.3725  | Running Time: 3.2737

Episode: 60/100  | Worker: 1 | Episode Reward: -1579.8769  | Running Time: 3.3140

Episode: 61/100  | Worker: 3 | Episode Reward: -1360.7916  | Running Time: 3.2856

Episode: 62/100  | Worker: 2 | Episode Reward: -1490.7107  | Running Time: 3.2861

Episode: 63/100  | Worker: 0 | Episode Reward: -1775.7557  | Running Time: 3.2644

Episode: 64/100  | Worker: 1 | Episode Reward: -1491.0894  | Running Time: 3.2828

Episode: 65/100  | Worker: 0 | Episode Reward: -1428.8124  | Running Time: 3.1239

Episode: 66/100  | Worker: 2 | Episode Reward: -1493.7703  | Running Time: 3.2680

Episode: 67/100  | Worker: 3 | Episode Reward: -1658.3558  | Running Time: 3.2853

Episode: 68/100  | Worker: 1 | Episode Reward: -1605.9077  | Running Time: 3.2911

Episode: 69/100  | Worker: 2 | Episode Reward: -1374.3309  | Running Time: 3.3644

Episode: 70/100  | Worker: 0 | Episode Reward: -1283.5023  | Running Time: 3.3819

Episode: 71/100  | Worker: 3 | Episode Reward: -1346.1850  | Running Time: 3.3860

Episode: 72/100  | Worker: 1 | Episode Reward: -1222.1988  | Running Time: 3.3724

Episode: 73/100  | Worker: 2 | Episode Reward: -1199.1266  | Running Time: 3.2739

Episode: 74/100  | Worker: 0 | Episode Reward: -1207.3161  | Running Time: 3.2670

Episode: 75/100  | Worker: 3 | Episode Reward: -1302.0207  | Running Time: 3.2562

Episode: 76/100  | Worker: 1 | Episode Reward: -1233.3584  | Running Time: 3.2892

Episode: 77/100  | Worker: 2 | Episode Reward: -964.8099  | Running Time: 3.2339

Episode: 78/100  | Worker: 0 | Episode Reward: -1208.2836  | Running Time: 3.2602

Episode: 79/100  | Worker: 3 | Episode Reward: -1149.2154  | Running Time: 3.2579

Episode: 80/100  | Worker: 1 | Episode Reward: -1219.3229  | Running Time: 3.2321

Episode: 81/100  | Worker: 2 | Episode Reward: -1097.7572  | Running Time: 3.2995

Episode: 82/100  | Worker: 3 | Episode Reward: -940.7949  | Running Time: 3.2981

Episode: 83/100  | Worker: 0 | Episode Reward: -1395.6272  | Running Time: 3.3076

Episode: 84/100  | Worker: 1 | Episode Reward: -1092.5180  | Running Time: 3.2936

Episode: 85/100  | Worker: 2 | Episode Reward: -1369.8868  | Running Time: 3.2517

Episode: 86/100  | Worker: 0 | Episode Reward: -1380.5247  | Running Time: 3.2390

Episode: 87/100  | Worker: 3 | Episode Reward: -1413.2114  | Running Time: 3.2740

Episode: 88/100  | Worker: 1 | Episode Reward: -1403.9904  | Running Time: 3.2643

Episode: 89/100  | Worker: 2 | Episode Reward: -1098.8470  | Running Time: 3.3078

Episode: 90/100  | Worker: 0 | Episode Reward: -983.4387  | Running Time: 3.3224

Episode: 91/100  | Worker: 3 | Episode Reward: -1056.6701  | Running Time: 3.3059

Episode: 92/100  | Worker: 1 | Episode Reward: -1357.6828  | Running Time: 3.2980

Episode: 93/100  | Worker: 2 | Episode Reward: -1082.3377  | Running Time: 3.3248

Episode: 94/100  | Worker: 3 | Episode Reward: -1052.0146  | Running Time: 3.3291

Episode: 95/100  | Worker: 0 | Episode Reward: -1373.0590  | Running Time: 3.3660

Episode: 96/100  | Worker: 1 | Episode Reward: -1044.4578  | Running Time: 3.3311

Episode: 97/100  | Worker: 2 | Episode Reward: -1179.2926  | Running Time: 3.3593

Episode: 98/100  | Worker: 3 | Episode Reward: -1039.1825  | Running Time: 3.3540

Episode: 99/100  | Worker: 0 | Episode Reward: -1193.3356  | Running Time: 3.3599

Episode: 100/100  | Worker: 1 | Episode Reward: -1378.5094  | Running Time: 3.2025

Episode: 101/100  | Worker: 2 | Episode Reward: -30.6317  | Running Time: 0.1128

Episode: 102/100  | Worker: 0 | Episode Reward: -141.0568  | Running Time: 0.2976

Episode: 103/100  | Worker: 3 | Episode Reward: -166.4818  | Running Time: 0.3256

Episode: 104/100  | Worker: 1 | Episode Reward: -123.2953  | Running Time: 0.2683

[TL] [*] Saving TL weights into model_Pendulum/dppo_actor.hdf5

[TL] [*] Saved

[TL] [*] Saving TL weights into model_Pendulum/dppo_actor_old.hdf5

[TL] [*] Saved

[TL] [*] Saving TL weights into model_Pendulum/dppo_critic.hdf5

[TL] [*] Saved

6. Model inference

The Notebook environment does not currently support visualizing Pendulum; download the code below and run it locally to see the visualization.

from matplotlib import animation

GLOBAL_PPO.load_ckpt()
env = gym.make(env_name)
s = env.reset()

def display_frames_as_gif(frames):
    patch = plt.imshow(frames[0])
    plt.axis('off')

    def animate(i):
        patch.set_data(frames[i])

    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=5)
    anim.save('./DPPO_Pendulum.gif', writer='imagemagick', fps=30)

total_reward = 0
frames = []

while True:
    env.render()
    frames.append(env.render(mode='rgb_array'))
    s, r, done, info = env.step(GLOBAL_PPO.choose_action(s))
    total_reward += r
    if done:
        print('It is over, the window will be closed after 1 second.')
        time.sleep(1)
        break
env.close()
print('Total Reward : %.2f' % total_reward)
display_frames_as_gif(frames)
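
If no local display is available, a simpler alternative is to skip rendering and just measure the average episode reward of the trained policy. The sketch below reuses GLOBAL_PPO and env_name defined above; the number of evaluation episodes is an arbitrary choice:

# Evaluate the trained policy without rendering: average total reward over a few episodes
eval_env = gym.make(env_name)        # keep the TimeLimit wrapper, so done is True after 200 steps
n_episodes = 5
scores = []
for _ in range(n_episodes):
    s = eval_env.reset()
    ep_r = 0.0
    done = False
    while not done:
        s, r, done, _ = eval_env.step(GLOBAL_PPO.choose_action(s))
        ep_r += r
    scores.append(ep_r)
eval_env.close()
print('Average reward over %d episodes: %.2f' % (n_episodes, np.mean(scores)))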

7. Inference result

The video below shows the inference result of a model trained for 1000 episodes.

8. Exercise

  1. Adjust the training parameters in step 2 and retrain a model so that it performs better in the game; one possible starting point is sketched below.
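
As one possible starting point (illustrative values only, not tuned or verified here), you could edit the constants in step 2 along these lines and re-run the training cells:

# Illustrative changes for the exercise: train longer and pool a larger batch per update
EP_MAX = 1000            # more training episodes (roughly 20 minutes, as noted in step 2)
MIN_BATCH_SIZE = 128     # pool more worker samples before each PPO update
A_LR, C_LR = 1e-4, 2e-4  # keep the learning rates, or lower them if training becomes unstable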