Controlling the "Pendulum" with the DPPO Algorithm
Objectives
By studying this case and completing the exercises, you will:
- Understand the basic concepts of DPPO
- Learn how to train a control problem with DPPO
- Understand the end-to-end workflow of training and inference for a reinforcement-learning control problem
You can also share the ipynb study notes for this case in the AI Gallery Notebook section to earn growth points; see this document for how to share.
Case Overview
The pendulum swing-up problem (Pendulum) is a classic problem in the control literature. In the Pendulum-v0 task that we solve with DPPO in this section, the pendulum starts at a random angle and swings around one fixed end; the goal is to swing the pendulum up and keep it upright. A demo of the pendulum taking random actions is shown below:
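A minimal random-action rollout (a sketch assuming the classic gym reset/step API used throughout this case) looks like this:
import gym

env = gym.make('Pendulum-v0')
s = env.reset()
for _ in range(200):
    a = env.action_space.sample()   # random torque in [-2, 2]
    s, r, done, _ = env.step(a)     # reward is 0 at best and grows more negative the further from upright
    if done:
        break
env.close()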
Overall workflow: install basic dependencies -> create the Pendulum environment -> build the DPPO algorithm -> train -> run inference -> visualize the results
Basic Structure of the Distributed Proximal Policy Optimization (DPPO) Algorithm
DPPO is developed on top of Proximal Policy Optimization (PPO); for PPO itself, see the tutorial "Playing 'Super Mario Bros.' with the PPO Algorithm", where it is introduced in detail. DPPO borrows the parallel scheme of A3C: multiple workers collect data in parallel in separate environments, compute gradients from the collected data, and send the gradients to a global chief. After the chief has received gradients from a sufficient number of workers, it updates the network; while it updates, the workers stop collecting and wait for the update to finish, and afterwards they resume collecting data with the latest network.
Below we use the pseudocode from the paper to describe the details of the DPPO algorithm.
The algorithm above is the pseudocode of the global PPO (the chief), where W is the number of workers, D is the threshold on the number of workers whose gradients must be available before the parameters are updated, M and B are the numbers of sub-iterations of policy-network and critic-network updates for a given batch of data points, and θ and Φ are the parameters of the policy network and the critic network, respectively.
The algorithm above is the pseudocode of a worker, where T is the number of data points each worker collects before a parameter update is computed, and K is the number of time steps over which K-step returns and truncated backpropagation through time are computed (for RNNs). This part is based on PPO: the worker first collects data, computes gradients according to the PPO algorithm and sends them to the global chief, and then waits until the chief has finished updating the parameters before collecting more data.
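In the implementation used later in this case, the workers push the transitions they collect into a shared queue and the global chief performs the gradient updates itself; the two sides are synchronized with two threading.Event flags. A stripped-down sketch of that handshake (the names mirror the training code below) is:
import threading

UPDATE_EVENT, ROLLING_EVENT = threading.Event(), threading.Event()
ROLLING_EVENT.set()                # workers may roll out

def worker_loop_step():
    ROLLING_EVENT.wait()           # block while the chief is updating
    # ... take one environment step and push the transition to a shared queue ...
    # once enough data has accumulated across all workers:
    ROLLING_EVENT.clear()          # pause data collection
    UPDATE_EVENT.set()             # wake the chief

def chief_update():
    UPDATE_EVENT.wait()            # sleep until enough data has arrived
    # ... pull all the data from the queue and update the actor and critic ...
    UPDATE_EVENT.clear()           # update finished
    ROLLING_EVENT.set()            # let the workers continue with the new policy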
Notes
- This case runs on TensorFlow-2.0.0 and requires a GPU. See the "ModelArts JupyterLab hardware specification guide" for how to switch hardware specifications.
- If you are using JupyterLab for the first time, see the "ModelArts JupyterLab user guide" for instructions.
- If you encounter errors while using JupyterLab, see the "ModelArts JupyterLab FAQ" for troubleshooting tips.
Experiment Steps
1. Program Initialization
Step 1: Install the basic dependencies
!pip install tensorflow==2.0.0
!pip install tensorflow-probability==0.7.0
!pip install tensorlayer==2.1.0 --ignore-installed
!pip install h5py==2.10.0
!pip install gym
Step 2: Import the required libraries
import os
import time
import queue
import threading
import gym
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import tensorlayer as tl
2. Training Parameter Initialization
Setting the maximum number of training episodes EP_MAX to 1000 gives good training results and takes about 20 minutes to train.
You can also use a smaller value of EP_MAX (the code below uses 100) to run through the code quickly.
RANDOMSEED = 1              # random seed
EP_MAX = 100                # total number of training episodes
EP_LEN = 200                # maximum length of one episode
GAMMA = 0.9                 # discount factor
A_LR = 0.0001               # actor learning rate
C_LR = 0.0002               # critic learning rate
BATCH = 32                  # batch size
A_UPDATE_STEPS = 10         # number of actor update steps
C_UPDATE_STEPS = 10         # number of critic update steps
S_DIM, A_DIM = 3, 1         # state dimension, action dimension
EPS = 1e-8                  # epsilon value
# Parameters of PPO1 and PPO2; you can choose PPO1 (METHOD[0]) or PPO2 (METHOD[1])
METHOD = [
    dict(name='kl_pen', kl_target=0.01, lam=0.5),   # KL penalty
    dict(name='clip', epsilon=0.2),                 # Clipped surrogate objective, find this is better
][1]                        # choose the method for optimization
N_WORKER = 4                # number of parallel workers
MIN_BATCH_SIZE = 64         # minibatch size for updating PPO
UPDATE_STEP = 10            # update every 10 steps
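For reference, the two surrogate objectives that METHOD chooses between come from the PPO paper. With probability ratio $r_t(\theta) = \pi_\theta(a_t \mid s_t) / \pi_{\theta_{old}}(a_t \mid s_t)$ and advantage estimate $\hat{A}_t$, they are

$$L^{KLPEN}(\theta) = \hat{\mathbb{E}}_t\!\left[ r_t(\theta)\,\hat{A}_t - \lambda\,\mathrm{KL}\!\left[\pi_{\theta_{old}}(\cdot \mid s_t),\, \pi_\theta(\cdot \mid s_t)\right] \right],$$

$$L^{CLIP}(\theta) = \hat{\mathbb{E}}_t\!\left[ \min\!\left( r_t(\theta)\,\hat{A}_t,\ \mathrm{clip}\big(r_t(\theta),\, 1-\epsilon,\, 1+\epsilon\big)\,\hat{A}_t \right) \right].$$

The a_train method below minimizes the negative of whichever objective METHOD selects, and for the KL-penalty variant it adapts λ toward kl_target (halving or doubling it) as in the adaptive-KL version of PPO.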
3. Create the Environment
The environment is gym's built-in Pendulum; letting the pendulum fall away from the upright position counts as failure.
env_name = 'Pendulum-v0' # environment name
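As a quick sanity check (an optional snippet assuming the classic gym API), you can confirm the dimensions used for S_DIM and A_DIM above:
import gym

env = gym.make('Pendulum-v0')
print(env.observation_space.shape)                  # (3,)  -> S_DIM = 3
print(env.action_space.shape)                       # (1,)  -> A_DIM = 1
print(env.action_space.low, env.action_space.high)  # [-2.] [2.] -> actions are clipped to [-2, 2]
env.close()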
4. Define the DPPO Algorithm
DPPO algorithm: the PPO (chief) part
class PPO(object):
    '''
    PPO class
    '''
    def __init__(self):
        # build the critic network
        tfs = tl.layers.Input([None, S_DIM], tf.float32, 'state')
        l1 = tl.layers.Dense(100, tf.nn.relu)(tfs)
        v = tl.layers.Dense(1)(l1)
        self.critic = tl.models.Model(tfs, v)
        self.critic.train()
        # build the actor networks (current policy and old policy)
        self.actor = self._build_anet('pi', trainable=True)
        self.actor_old = self._build_anet('oldpi', trainable=False)
        self.actor_opt = tf.optimizers.Adam(A_LR)
        self.critic_opt = tf.optimizers.Adam(C_LR)
    # update the actor
    def a_train(self, tfs, tfa, tfadv):
        '''
        Update policy network
        :param tfs: state
        :param tfa: act
        :param tfadv: advantage
        :return:
        '''
        tfs = np.array(tfs, np.float32)
        tfa = np.array(tfa, np.float32)
        tfadv = np.array(tfadv, np.float32)  # advantage estimate
        with tf.GradientTape() as tape:
            mu, sigma = self.actor(tfs)
            pi = tfp.distributions.Normal(mu, sigma)
            mu_old, sigma_old = self.actor_old(tfs)
            oldpi = tfp.distributions.Normal(mu_old, sigma_old)
            ratio = pi.prob(tfa) / (oldpi.prob(tfa) + EPS)
            surr = ratio * tfadv
            ## PPO1
            if METHOD['name'] == 'kl_pen':
                tflam = METHOD['lam']
                kl = tfp.distributions.kl_divergence(oldpi, pi)
                kl_mean = tf.reduce_mean(kl)
                aloss = -(tf.reduce_mean(surr - tflam * kl))
            ## PPO2
            else:
                aloss = -tf.reduce_mean(
                    tf.minimum(surr,
                               tf.clip_by_value(ratio, 1. - METHOD['epsilon'], 1. + METHOD['epsilon']) * tfadv)
                )
        a_grad = tape.gradient(aloss, self.actor.trainable_weights)
        self.actor_opt.apply_gradients(zip(a_grad, self.actor.trainable_weights))
        if METHOD['name'] == 'kl_pen':
            return kl_mean
    # update the old policy (old_pi)
    def update_old_pi(self):
        '''
        Update old policy parameter
        :return: None
        '''
        for p, oldp in zip(self.actor.trainable_weights, self.actor_old.trainable_weights):
            oldp.assign(p)
    # update the critic
    def c_train(self, tfdc_r, s):
        '''
        Update critic network
        :param tfdc_r: cumulative reward
        :param s: state
        :return: None
        '''
        tfdc_r = np.array(tfdc_r, dtype=np.float32)
        with tf.GradientTape() as tape:
            advantage = tfdc_r - self.critic(s)  # advantage = discounted return - V(s)
            closs = tf.reduce_mean(tf.square(advantage))
        grad = tape.gradient(closs, self.critic.trainable_weights)
        self.critic_opt.apply_gradients(zip(grad, self.critic.trainable_weights))
    # compute the advantage: discounted return - V(s)
    def cal_adv(self, tfs, tfdc_r):
        '''
        Calculate advantage
        :param tfs: state
        :param tfdc_r: cumulative reward
        :return: advantage
        '''
        tfdc_r = np.array(tfdc_r, dtype=np.float32)
        advantage = tfdc_r - self.critic(tfs)
        return advantage.numpy()
    def update(self):
        '''
        Update parameter with the constraint of KL divergent
        :return: None
        '''
        global GLOBAL_UPDATE_COUNTER
        while not COORD.should_stop():                      # run until the coordinator requests a stop
            if GLOBAL_EP < EP_MAX:                          # EP_MAX is the maximum number of training episodes
                UPDATE_EVENT.wait()                         # the chief (PPO) thread waits here
                self.update_old_pi()                        # copy pi to old pi
                data = [QUEUE.get() for _ in range(QUEUE.qsize())]  # collect data from all workers
                data = np.vstack(data)
                s, a, r = data[:, :S_DIM].astype(np.float32), \
                          data[:, S_DIM: S_DIM + A_DIM].astype(np.float32), \
                          data[:, -1:].astype(np.float32)
                adv = self.cal_adv(s, r)
                # update actor
                ## PPO1
                if METHOD['name'] == 'kl_pen':
                    for _ in range(A_UPDATE_STEPS):
                        kl = self.a_train(s, a, adv)
                        if kl > 4 * METHOD['kl_target']:    # this is in Google's paper
                            break
                    if kl < METHOD['kl_target'] / 1.5:      # adaptive lambda, this is in OpenAI's paper
                        METHOD['lam'] /= 2
                    elif kl > METHOD['kl_target'] * 1.5:
                        METHOD['lam'] *= 2
                    # sometimes explode, this clipping is MorvanZhou's solution
                    METHOD['lam'] = np.clip(METHOD['lam'], 1e-4, 10)
                ## PPO2
                else:                                       # clipping method, find this is better (OpenAI's paper)
                    for _ in range(A_UPDATE_STEPS):
                        self.a_train(s, a, adv)
                # update critic
                for _ in range(C_UPDATE_STEPS):
                    self.c_train(r, s)
                UPDATE_EVENT.clear()                        # updating finished
                GLOBAL_UPDATE_COUNTER = 0                   # reset counter
                ROLLING_EVENT.set()                         # set roll-out available
    # build the actor (policy) network
    def _build_anet(self, name, trainable):
        '''
        Build policy network
        :param name: name
        :param trainable: trainable flag
        :return: policy network
        '''
        tfs = tl.layers.Input([None, S_DIM], tf.float32, name + '_state')
        l1 = tl.layers.Dense(100, tf.nn.relu, name=name + '_l1')(tfs)
        a = tl.layers.Dense(A_DIM, tf.nn.tanh, name=name + '_a')(l1)
        mu = tl.layers.Lambda(lambda x: x * 2, name=name + '_lambda')(a)
        sigma = tl.layers.Dense(A_DIM, tf.nn.softplus, name=name + '_sigma')(l1)
        model = tl.models.Model(tfs, [mu, sigma], name)
        if trainable:
            model.train()
        else:
            model.eval()
        return model
    # choose an action
    def choose_action(self, s):
        '''
        Choose action
        :param s: state
        :return: clipped act
        '''
        s = s[np.newaxis, :].astype(np.float32)
        mu, sigma = self.actor(s)
        pi = tfp.distributions.Normal(mu, sigma)
        a = tf.squeeze(pi.sample(1), axis=0)[0]  # choosing action
        return np.clip(a, -2, 2)
    # compute V(s)
    def get_v(self, s):
        '''
        Compute value
        :param s: state
        :return: value
        '''
        s = s.astype(np.float32)
        if s.ndim < 2: s = s[np.newaxis, :]
        return self.critic(s)[0, 0]
    def save_ckpt(self):
        """
        save trained weights
        :return: None
        """
        if not os.path.exists('model_Pendulum'):
            os.makedirs('model_Pendulum')
        tl.files.save_weights_to_hdf5('model_Pendulum/dppo_actor.hdf5', self.actor)
        tl.files.save_weights_to_hdf5('model_Pendulum/dppo_actor_old.hdf5', self.actor_old)
        tl.files.save_weights_to_hdf5('model_Pendulum/dppo_critic.hdf5', self.critic)

    def load_ckpt(self):
        """
        load trained weights
        :return: None
        """
        tl.files.load_hdf5_to_weights_in_order('model_Pendulum/dppo_actor.hdf5', self.actor)
        tl.files.load_hdf5_to_weights_in_order('model_Pendulum/dppo_actor_old.hdf5', self.actor_old)
        tl.files.load_hdf5_to_weights_in_order('model_Pendulum/dppo_critic.hdf5', self.critic)
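For evaluation you may prefer a deterministic action. A hypothetical helper (not part of the original class) that takes the policy mean instead of sampling could look like this:
def choose_action_greedy(ppo, s):
    # hypothetical evaluation helper: act with the policy mean rather than sampling from it
    s = s[np.newaxis, :].astype(np.float32)
    mu, _ = ppo.actor(s)
    return np.clip(mu.numpy()[0], -2, 2)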
Building the workers
class Worker(object):
    '''
    Worker class for distributional running
    '''
    def __init__(self, wid):
        self.wid = wid                                  # worker id
        self.env = gym.make(env_name).unwrapped         # create the environment
        self.env.seed(wid * 100 + RANDOMSEED)           # give each worker a different random seed so they do not behave identically
        self.ppo = GLOBAL_PPO                           # the shared global PPO algorithm
    def work(self):
        '''
        Define a worker
        :return: None
        '''
        global GLOBAL_EP, GLOBAL_RUNNING_R, GLOBAL_UPDATE_COUNTER
        while not COORD.should_stop():                      # check with COORD whether we should stop
            s = self.env.reset()
            ep_r = 0
            buffer_s, buffer_a, buffer_r = [], [], []       # data buffers
            t0 = time.time()
            for t in range(EP_LEN):
                # check whether the global PPO is currently being updated; if so, wait here
                if not ROLLING_EVENT.is_set():              # if rolling is blocked, the global PPO is updating; otherwise carry on
                    ROLLING_EVENT.wait()                    # the worker thread waits here until PPO has been updated
                    buffer_s, buffer_a, buffer_r = [], [], []  # clear history buffer, use new policy to collect data
                # play the game normally and collect data
                a = self.ppo.choose_action(s)
                s_, r, done, _ = self.env.step(a)
                buffer_s.append(s)
                buffer_a.append(a)
                buffer_r.append((r + 8) / 8)                # normalize reward, find to be useful
                s = s_
                ep_r += r
                # GLOBAL_UPDATE_COUNTER increases by 1 every time any worker takes one environment step, i.e. produces one data point.
                # Once GLOBAL_UPDATE_COUNTER reaches MIN_BATCH_SIZE (64), an update can be triggered.
                GLOBAL_UPDATE_COUNTER += 1                  # count to minimum batch size, no need to wait other workers
                if t == EP_LEN - 1 or GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:  # t == EP_LEN - 1 is the last step
                    ## compute V(s') for bootstrapping
                    ## note that len(buffer) < GLOBAL_UPDATE_COUNTER, so each worker computes the returns on its own data
                    v_s_ = self.ppo.get_v(s_)
                    discounted_r = []                       # compute discounted reward
                    for r in buffer_r[::-1]:
                        v_s_ = r + GAMMA * v_s_
                        discounted_r.append(v_s_)
                    discounted_r.reverse()
                    ## stack the data and push it to the shared queue
                    bs, ba, br = np.vstack(buffer_s), np.vstack(buffer_a), np.array(discounted_r)[:, np.newaxis]
                    buffer_s, buffer_a, buffer_r = [], [], []
                    QUEUE.put(np.hstack((bs, ba, br)))      # put data in the queue
                    # if enough data has been collected, trigger an update
                    if GLOBAL_UPDATE_COUNTER >= MIN_BATCH_SIZE:
                        ROLLING_EVENT.clear()               # stop collecting data
                        UPDATE_EVENT.set()                  # global PPO update
                    if GLOBAL_EP >= EP_MAX:                 # stop training
                        COORD.request_stop()                # stop all threads
                        break
            # record reward changes, plot later
            if len(GLOBAL_RUNNING_R) == 0:
                GLOBAL_RUNNING_R.append(ep_r)
            else:
                GLOBAL_RUNNING_R.append(GLOBAL_RUNNING_R[-1] * 0.9 + ep_r * 0.1)
            GLOBAL_EP += 1
            print(
                'Episode: {}/{} | Worker: {} | Episode Reward: {:.4f} | Running Time: {:.4f}'.format(
                    GLOBAL_EP, EP_MAX, self.wid, ep_r,
                    time.time() - t0
                )
            )
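To make the bootstrapped discounted return computed in Worker.work concrete, here is a tiny standalone illustration (the segment rewards and the bootstrap value V(s') = 5 are made up for the example; GAMMA = 0.9 as in this case):
GAMMA = 0.9
buffer_r, v_s_ = [1.0, 1.0, 1.0], 5.0   # assumed segment rewards and bootstrap value V(s')
discounted_r = []
for r in buffer_r[::-1]:                # the same backward recursion as in Worker.work
    v_s_ = r + GAMMA * v_s_
    discounted_r.append(v_s_)
discounted_r.reverse()
print(discounted_r)                     # approximately [6.355, 5.95, 5.5]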
5. Model Training
np.random.seed(RANDOMSEED)
tf.random.set_seed(RANDOMSEED)
GLOBAL_PPO = PPO()
[TL] Input state: [None, 3]
[TL] Dense dense_1: 100 relu
[TL] Dense dense_2: 1 No Activation
[TL] Input pi_state: [None, 3]
[TL] Dense pi_l1: 100 relu
[TL] Dense pi_a: 1 tanh
[TL] Lambda pi_lambda: func: <function PPO._build_anet.<locals>.<lambda> at 0x7fb65d633950>, len_weights: 0
[TL] Dense pi_sigma: 1 softplus
[TL] Input oldpi_state: [None, 3]
[TL] Dense oldpi_l1: 100 relu
[TL] Dense oldpi_a: 1 tanh
[TL] Lambda oldpi_lambda: func: <function PPO._build_anet.<locals>.<lambda> at 0x7fb65d633a70>, len_weights: 0
[TL] Dense oldpi_sigma: 1 softplus
# define two events: update and rolling
UPDATE_EVENT, ROLLING_EVENT = threading.Event(), threading.Event()
UPDATE_EVENT.clear()            # not updating now; the flag is set to False
ROLLING_EVENT.set()             # start to roll out; the flag is set to True and all threads blocked in wait() are woken up
# create the workers
workers = [Worker(wid=i) for i in range(N_WORKER)]
GLOBAL_UPDATE_COUNTER, GLOBAL_EP = 0, 0     # global update-step counter, global episode counter
GLOBAL_RUNNING_R = []                       # record the running reward to monitor training progress
COORD = tf.train.Coordinator()              # create a TensorFlow coordinator
QUEUE = queue.Queue()                       # workers putting data in this queue
threads = []
# create a thread for each worker
for worker in workers:                      # worker threads
    t = threading.Thread(target=worker.work, args=())  # create the thread
    t.start()                               # start the thread
    threads.append(t)                       # keep the thread in a list so it is easy to manage
# add a PPO updating thread
# append the global PPO update thread at the end of the thread list
threads.append(threading.Thread(target=GLOBAL_PPO.update, ))
threads[-1].start()
COORD.join(threads)                         # hand the thread list to the coordinator
GLOBAL_PPO.save_ckpt()                      # save the global parameters
# plot reward change and test
plt.title('DPPO')
plt.plot(np.arange(len(GLOBAL_RUNNING_R)), GLOBAL_RUNNING_R)
plt.xlabel('Episode')
plt.ylabel('Moving reward')
plt.ylim(-2000, 0)
plt.show()
Episode: 1/100 | Worker: 1 | Episode Reward: -965.6343 | Running Time: 3.1675
Episode: 2/100 | Worker: 2 | Episode Reward: -1443.1138 | Running Time: 3.1689
Episode: 3/100 | Worker: 3 | Episode Reward: -1313.6248 | Running Time: 3.1734
Episode: 4/100 | Worker: 0 | Episode Reward: -1403.1952 | Running Time: 3.1819
Episode: 5/100 | Worker: 1 | Episode Reward: -1399.3963 | Running Time: 3.2429
Episode: 6/100 | Worker: 2 | Episode Reward: -1480.8439 | Running Time: 3.2453
Episode: 7/100 | Worker: 0 | Episode Reward: -1489.4195 | Running Time: 3.2373
Episode: 8/100 | Worker: 3 | Episode Reward: -1339.0517 | Running Time: 3.2583
Episode: 9/100 | Worker: 1 | Episode Reward: -1600.1292 | Running Time: 3.2478
Episode: 10/100 | Worker: 0 | Episode Reward: -1513.2170 | Running Time: 3.2584
Episode: 11/100 | Worker: 2 | Episode Reward: -1461.7279 | Running Time: 3.2697
Episode: 12/100 | Worker: 3 | Episode Reward: -1480.2685 | Running Time: 3.2598
Episode: 13/100 | Worker: 0 | Episode Reward: -1831.5374 | Running Time: 3.2423
Episode: 14/100 | Worker: 1 | Episode Reward: -1524.8253 | Running Time: 3.2635
Episode: 15/100 | Worker: 2 | Episode Reward: -1383.4878 | Running Time: 3.2556
Episode: 16/100 | Worker: 3 | Episode Reward: -1288.9392 | Running Time: 3.2588
Episode: 17/100 | Worker: 1 | Episode Reward: -1657.2223 | Running Time: 3.2377
Episode: 18/100 | Worker: 0 | Episode Reward: -1472.2335 | Running Time: 3.2678
Episode: 19/100 | Worker: 2 | Episode Reward: -1475.5421 | Running Time: 3.2667
Episode: 20/100 | Worker: 3 | Episode Reward: -1532.7678 | Running Time: 3.2739
Episode: 21/100 | Worker: 1 | Episode Reward: -1575.5706 | Running Time: 3.2688
Episode: 22/100 | Worker: 2 | Episode Reward: -1238.4006 | Running Time: 3.2303
Episode: 23/100 | Worker: 0 | Episode Reward: -1630.9554 | Running Time: 3.2584
Episode: 24/100 | Worker: 3 | Episode Reward: -1610.7237 | Running Time: 3.2601
Episode: 25/100 | Worker: 1 | Episode Reward: -1516.5440 | Running Time: 3.2683
Episode: 26/100 | Worker: 0 | Episode Reward: -1547.6209 | Running Time: 3.2589
Episode: 27/100 | Worker: 2 | Episode Reward: -1328.2584 | Running Time: 3.2762
Episode: 28/100 | Worker: 3 | Episode Reward: -1191.0914 | Running Time: 3.2552
Episode: 29/100 | Worker: 1 | Episode Reward: -1415.3608 | Running Time: 3.2804
Episode: 30/100 | Worker: 0 | Episode Reward: -1765.8007 | Running Time: 3.2767
Episode: 31/100 | Worker: 2 | Episode Reward: -1756.5872 | Running Time: 3.3078
Episode: 32/100 | Worker: 3 | Episode Reward: -1428.0094 | Running Time: 3.2815
Episode: 33/100 | Worker: 1 | Episode Reward: -1605.7720 | Running Time: 3.3010
Episode: 34/100 | Worker: 0 | Episode Reward: -1247.7492 | Running Time: 3.3115
Episode: 35/100 | Worker: 2 | Episode Reward: -1333.9553 | Running Time: 3.2759
Episode: 36/100 | Worker: 3 | Episode Reward: -1485.7453 | Running Time: 3.2749
Episode: 37/100 | Worker: 3 | Episode Reward: -1341.3090 | Running Time: 3.2323
Episode: 38/100 | Worker: 2 | Episode Reward: -1472.5245 | Running Time: 3.2595
Episode: 39/100 | Worker: 0 | Episode Reward: -1583.6614 | Running Time: 3.2721
Episode: 40/100 | Worker: 1 | Episode Reward: -1358.4421 | Running Time: 3.2925
Episode: 41/100 | Worker: 3 | Episode Reward: -1744.7500 | Running Time: 3.2391
Episode: 42/100 | Worker: 2 | Episode Reward: -1684.8821 | Running Time: 3.2527
Episode: 43/100 | Worker: 1 | Episode Reward: -1412.0231 | Running Time: 3.2400
Episode: 44/100 | Worker: 0 | Episode Reward: -1437.6130 | Running Time: 3.2458
Episode: 45/100 | Worker: 3 | Episode Reward: -1461.7901 | Running Time: 3.2872
Episode: 46/100 | Worker: 2 | Episode Reward: -1572.6255 | Running Time: 3.2710
Episode: 47/100 | Worker: 0 | Episode Reward: -1704.6351 | Running Time: 3.2762
Episode: 48/100 | Worker: 1 | Episode Reward: -1538.4030 | Running Time: 3.3117
Episode: 49/100 | Worker: 3 | Episode Reward: -1554.7941 | Running Time: 3.2881
Episode: 50/100 | Worker: 2 | Episode Reward: -1796.0786 | Running Time: 3.2718
Episode: 51/100 | Worker: 0 | Episode Reward: -1877.3152 | Running Time: 3.2804
Episode: 52/100 | Worker: 1 | Episode Reward: -1749.8780 | Running Time: 3.2779
Episode: 53/100 | Worker: 3 | Episode Reward: -1486.8338 | Running Time: 3.1559
Episode: 54/100 | Worker: 2 | Episode Reward: -1540.8134 | Running Time: 3.2903
Episode: 55/100 | Worker: 0 | Episode Reward: -1596.7365 | Running Time: 3.3156
Episode: 56/100 | Worker: 1 | Episode Reward: -1644.7888 | Running Time: 3.3065
Episode: 57/100 | Worker: 3 | Episode Reward: -1514.0685 | Running Time: 3.2920
Episode: 58/100 | Worker: 2 | Episode Reward: -1411.2714 | Running Time: 3.1554
Episode: 59/100 | Worker: 0 | Episode Reward: -1602.3725 | Running Time: 3.2737
Episode: 60/100 | Worker: 1 | Episode Reward: -1579.8769 | Running Time: 3.3140
Episode: 61/100 | Worker: 3 | Episode Reward: -1360.7916 | Running Time: 3.2856
Episode: 62/100 | Worker: 2 | Episode Reward: -1490.7107 | Running Time: 3.2861
Episode: 63/100 | Worker: 0 | Episode Reward: -1775.7557 | Running Time: 3.2644
Episode: 64/100 | Worker: 1 | Episode Reward: -1491.0894 | Running Time: 3.2828
Episode: 65/100 | Worker: 0 | Episode Reward: -1428.8124 | Running Time: 3.1239
Episode: 66/100 | Worker: 2 | Episode Reward: -1493.7703 | Running Time: 3.2680
Episode: 67/100 | Worker: 3 | Episode Reward: -1658.3558 | Running Time: 3.2853
Episode: 68/100 | Worker: 1 | Episode Reward: -1605.9077 | Running Time: 3.2911
Episode: 69/100 | Worker: 2 | Episode Reward: -1374.3309 | Running Time: 3.3644
Episode: 70/100 | Worker: 0 | Episode Reward: -1283.5023 | Running Time: 3.3819
Episode: 71/100 | Worker: 3 | Episode Reward: -1346.1850 | Running Time: 3.3860
Episode: 72/100 | Worker: 1 | Episode Reward: -1222.1988 | Running Time: 3.3724
Episode: 73/100 | Worker: 2 | Episode Reward: -1199.1266 | Running Time: 3.2739
Episode: 74/100 | Worker: 0 | Episode Reward: -1207.3161 | Running Time: 3.2670
Episode: 75/100 | Worker: 3 | Episode Reward: -1302.0207 | Running Time: 3.2562
Episode: 76/100 | Worker: 1 | Episode Reward: -1233.3584 | Running Time: 3.2892
Episode: 77/100 | Worker: 2 | Episode Reward: -964.8099 | Running Time: 3.2339
Episode: 78/100 | Worker: 0 | Episode Reward: -1208.2836 | Running Time: 3.2602
Episode: 79/100 | Worker: 3 | Episode Reward: -1149.2154 | Running Time: 3.2579
Episode: 80/100 | Worker: 1 | Episode Reward: -1219.3229 | Running Time: 3.2321
Episode: 81/100 | Worker: 2 | Episode Reward: -1097.7572 | Running Time: 3.2995
Episode: 82/100 | Worker: 3 | Episode Reward: -940.7949 | Running Time: 3.2981
Episode: 83/100 | Worker: 0 | Episode Reward: -1395.6272 | Running Time: 3.3076
Episode: 84/100 | Worker: 1 | Episode Reward: -1092.5180 | Running Time: 3.2936
Episode: 85/100 | Worker: 2 | Episode Reward: -1369.8868 | Running Time: 3.2517
Episode: 86/100 | Worker: 0 | Episode Reward: -1380.5247 | Running Time: 3.2390
Episode: 87/100 | Worker: 3 | Episode Reward: -1413.2114 | Running Time: 3.2740
Episode: 88/100 | Worker: 1 | Episode Reward: -1403.9904 | Running Time: 3.2643
Episode: 89/100 | Worker: 2 | Episode Reward: -1098.8470 | Running Time: 3.3078
Episode: 90/100 | Worker: 0 | Episode Reward: -983.4387 | Running Time: 3.3224
Episode: 91/100 | Worker: 3 | Episode Reward: -1056.6701 | Running Time: 3.3059
Episode: 92/100 | Worker: 1 | Episode Reward: -1357.6828 | Running Time: 3.2980
Episode: 93/100 | Worker: 2 | Episode Reward: -1082.3377 | Running Time: 3.3248
Episode: 94/100 | Worker: 3 | Episode Reward: -1052.0146 | Running Time: 3.3291
Episode: 95/100 | Worker: 0 | Episode Reward: -1373.0590 | Running Time: 3.3660
Episode: 96/100 | Worker: 1 | Episode Reward: -1044.4578 | Running Time: 3.3311
Episode: 97/100 | Worker: 2 | Episode Reward: -1179.2926 | Running Time: 3.3593
Episode: 98/100 | Worker: 3 | Episode Reward: -1039.1825 | Running Time: 3.3540
Episode: 99/100 | Worker: 0 | Episode Reward: -1193.3356 | Running Time: 3.3599
Episode: 100/100 | Worker: 1 | Episode Reward: -1378.5094 | Running Time: 3.2025
Episode: 101/100 | Worker: 2 | Episode Reward: -30.6317 | Running Time: 0.1128
Episode: 102/100 | Worker: 0 | Episode Reward: -141.0568 | Running Time: 0.2976
Episode: 103/100 | Worker: 3 | Episode Reward: -166.4818 | Running Time: 0.3256
Episode: 104/100 | Worker: 1 | Episode Reward: -123.2953 | Running Time: 0.2683
[TL] [*] Saving TL weights into model_Pendulum/dppo_actor.hdf5
[TL] [*] Saved
[TL] [*] Saving TL weights into model_Pendulum/dppo_actor_old.hdf5
[TL] [*] Saved
[TL] [*] Saving TL weights into model_Pendulum/dppo_critic.hdf5
[TL] [*] Saved
6. Model Inference
The notebook does not currently support Pendulum visualization. Download the code below and run it locally to view the visualized result.
from matplotlib import animation

GLOBAL_PPO.load_ckpt()
env = gym.make(env_name)
s = env.reset()

def display_frames_as_gif(frames):
    patch = plt.imshow(frames[0])
    plt.axis('off')
    def animate(i):
        patch.set_data(frames[i])
    anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=5)
    anim.save('./DPPO_Pendulum.gif', writer='imagemagick', fps=30)

total_reward = 0
frames = []
while True:
    env.render()
    frames.append(env.render(mode='rgb_array'))
    s, r, done, info = env.step(GLOBAL_PPO.choose_action(s))
    total_reward += r                       # accumulate the episode reward
    if done:
        print('It is over, the window will be closed after 1 second.')
        time.sleep(1)
        break
env.close()
print('Total Reward : %.2f' % total_reward)
display_frames_as_gif(frames)
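If you would rather render inside a headless environment (such as this notebook) than locally, one common workaround (an assumption, not part of the original case: it requires the pyvirtualdisplay package and the system tool Xvfb) is to start a virtual display before rendering:
# assumed workaround for headless rendering; needs `pip install pyvirtualdisplay` and Xvfb installed on the system
from pyvirtualdisplay import Display

virtual_display = Display(visible=0, size=(400, 300))   # create an invisible X display
virtual_display.start()
# after this, env.render(mode='rgb_array') can return frames without a physical display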
7. Model Inference Results
The following video shows the inference result of a model trained for 1000 episodes.
8. Exercises