当前位置:网站首页>Double Q-Learning理论基础及其代码实现【Pendulum-v0】
Double Q-Learning理论基础及其代码实现【Pendulum-v0】
2022-07-17 00:13:00 【lucky-wz】
DQL 理论基础
为了方便,本文中Q-Learning算法记作QL,Double Q-Learning算法记作DQL。
首先,可能很多人都或多或少的听到QL相关算法通常会过高的估计在特定条件下的动作值。事实上过估计存在一定的风险,比如Hado van Hasselt,Arthur Guez和David Silver在论文《Deep Reinforcement Learning with Double Q-learning》指出 DQN 算法,的确存在特定动作在运行 Atari 2600 时会遭受严重的高估,从而会极大的影响算法的性能。而本文的主角DQL算法可以很好的降低观测到的过高估计动作的问题,而且在一些游戏上取得了更好的效果。
强化学习的目标是通过优化累积的未来奖励信号来学习序贯决策问题。 QL算法无疑是最受欢迎的强化学习算法之一,但众所周知,它有时会学习不切合实际的高动作值,因为它包含了一个超过估计动作值的最大化步骤,这往往更倾向于有个过高估值的问题。
那么什么是过估计(overestimate)呢?过估计是指对一系列数先求最大值再求平均,通常比先求平均再求最大值要大。数学表达式为
E ( max ( X 1 , X 2 , … ) ) ≥ max ( E ( X 1 ) , E ( X 2 ) , … ) E(\max (X_1, X_2, \ldots)) \geq \max (E(X_1), E(X_2), \ldots) E(max(X1,X2,…))≥max(E(X1),E(X2),…)
一般来说QL方法导致过估计的原因主要归结于其更新过程,其表达为:
Q t + 1 ( s t , a t ) = Q t ( s t , a t ) + α t ( s t , a t ) ( r t + γ max a Q t ( s t + 1 , a ) − Q t ( s t , a t ) ) Q_{t+1}\left(s_{t}, a_{t}\right)=Q_{t}\left(s_{t}, a_{t}\right)+\alpha_{t}\left(s_{t}, a_{t}\right)\left(r_{t}+\gamma \max _{a} Q_{t}\left(s_{t+1}, a\right)-Q_{t}\left(s_{t}, a_{t}\right)\right) Qt+1(st,at)=Qt(st,at)+αt(st,at)(rt+γamaxQt(st+1,a)−Qt(st,at))
其中的 max a \max _{a} maxa表示为最大化动作价值函数,而更新最优化过程如下:
∀ s , a : Q ∗ ( s , a ) = ∑ s ′ P s a s ′ ( R s a s ′ + γ max a Q ∗ ( s ′ , a ) ) \forall s, a: Q^{*}(s, a)=\sum_{s^{\prime}} P_{s a}^{s^{\prime}}\left(R_{s a}^{s^{\prime}}+\gamma \max _{a} Q^{*}\left(s^{\prime}, a\right)\right) ∀s,a:Q∗(s,a)=s′∑Psas′(Rsas′+γamaxQ∗(s′,a))
对于任意的 s s s和 a a a来说,最优值函数 Q ∗ Q^{*} Q∗的更新依赖于 max a Q ∗ ( s , … ) \max _{a} Q^{*}(s, \ldots) maxaQ∗(s,…)。从公式中可以看出,我们把 N N N个 Q Q Q值先通过取 max \max max操作之后,然后求平均,会比我们先算出 N N N个 Q Q Q值取了期望之后再 m a x max max要大。这就是过高估计的原因。
在QL中我们让目标策略的动作为当前状态下动作值函数取得最大的动作,**而这个最大化操作会导致严重的正向偏差,我们称之为最大化偏差。**怎么理解这个正向偏差呢?假设对于一个状态 s s s来说,有很多个动作 s s s可以选择。而每个 ( s , a ) (s,a) (s,a)真实的值 Q ( s , a ) Q(s,a) Q(s,a)都为0。但是由于估计偏差或者不确定性导致估计的值 Q ( s , a ) Q(s,a) Q(s,a)要么大于0,要么小于0。那么对估计值做最大化操作后,就得到了一个正值,显然相对于真实的值0,这是一个正向偏差。
我们再考察一下QL算法,为了得到更新目标 R t + 1 + γ max a Q ( S t + 1 , a ) R_{t+1}+\gamma \max _{a} Q\left(S_{t+1}, a\right) Rt+1+γmaxaQ(St+1,a) ,我们需要已知两个条件:
- 真实的 Q ( S t + 1 , ⋅ ) Q(S_{t+1}, \cdot) Q(St+1,⋅)
- 哪个动作 a a a使得 Q ( S t + 1 , ⋅ ) Q(S_{t+1}, \cdot) Q(St+1,⋅)最大。
在QL中,我们使用了相同的数据来估计这两个条件,这相当于在已有最大化偏差 Q ( S t + 1 , ⋅ ) Q(S_{t+1}, \cdot) Q(St+1,⋅)的基础上又做了最大化操作。基准都可能是错的,再找最大化的动作就没什么意义了。所以我们要把这两个过程分开,这就是DQL背后的想法。在DQL中,我们同时估计两个值 Q A ( a ) Q_A(a) QA(a)和 Q B ( a ) Q_B(a) QB(a),然后我们可以用其中一个估计来决定最大化动作,比如 A ∗ = arg max a Q A ( a ) A^{*}=\arg \max _{a} Q_{A}(a) A∗=argmaxaQA(a),用另一个估计 Q B Q_B QB来决定状态的值 Q B ( A ∗ ) = Q B ( arg max a Q A ( a ) ) Q_{B}\left(A^{*}\right)=Q_{B}\left(\arg \max _{a} Q_{A}(a)\right) QB(A∗)=QB(argmaxaQA(a))。这样就无偏了。这就是DQL。值得注意的是,尽管我们有两个估计,但是在一个时间步只会更新一个估计。什么意思呢?在DQL中每个 Q Q Q函数都会使用另一个 Q Q Q函数的值更新下一个状态,而且两个 Q Q Q函数都必须从不同的经验集中学习,但是选择要执行的动作可以同时使用两个值函数。因此DQL并没有使计算量增加一倍,只是需要增加一倍的内存来存储另一个估计。对于用网络近似值函数的情况来说,就是多了一个网络。
划重点:DQL算法的数据效率不低于QL算法。
在实验中作者为每个动作计算了两个Q值的平均值,然后对所得的平均Q值进行了贪婪探索。算法伪代码如下:
此处对于QL算法和DQL算法来说,DQL使用了B网络来更新A网络,同样的道理对于B网络则使用A网络的值来更新。
说了这么多,可以给出DQL的更新公式了:
Q 1 ( S t , A t ) ← Q 1 ( S t , A t ) + α [ R t + 1 + γ Q 2 ( S t + 1 , arg max a Q 1 ( S t + 1 , a ) ) − Q 1 ( S t , A t ) ] Q_{1}\left(S_{t}, A_{t}\right) \leftarrow Q_{1}\left(S_{t}, A_{t}\right)+\alpha\left[R_{t+1}+\gamma Q_{2}\left(S_{t+1}, \arg \max _{a} Q_{1}\left(S_{t+1}, a\right)\right)-Q_{1}\left(S_{t}, A_{t}\right)\right] Q1(St,At)←Q1(St,At)+α[Rt+1+γQ2(St+1,argamaxQ1(St+1,a))−Q1(St,At)]
在每次更新Q表时,我们以0.5的概率使用上式更新 Q 1 Q_1 Q1 ,同样的也会有0.5的概率更新 Q 2 Q_2 Q2 , Q 2 Q_2 Q2的更新公式可以仿照上式给出。
DQL 代码实现
DQL的代码实现,其实在QL基础上稍加改动就可以了,主要的区别在于DQL使用了两个Q表格,并且选择动作时,是使用两个表格联合求得的,在更新时随机更新某一个Q表格。
import numpy as np
import matplotlib.pyplot as plt
import gym
import math
def moving_average(a, window_size):
"""滑动平均"""
cumulative_sum = np.cumsum(np.insert(a, 0, 0))
middle = (cumulative_sum[window_size:] - cumulative_sum[:-window_size]) / window_size
r = np.arange(1, window_size - 1, 2)
begin = np.cumsum(a[:window_size - 1])[::2] / r
end = (np.cumsum(a[:-window_size:-1])[::2] / r)[::-1]
return np.concatenate((begin, middle, end))
class DoubleQLearning:
def __init__(self):
self.env = gym.make('Pendulum-v0')
# np.random.seed(0)
# self.env.seed(0)
self.num_states = self.env.observation_space.shape[0]
self.gamma = 0.9 # decrease rate
self.lr = 0.1 # learning rate
self.max_steps = 200 # steps for 1 episode
self.num_episodes = 5000 # number of episodes
self.epsilon = 0.95
# uniform distributed sample with size
self.qA_table = np.random.uniform(low=-1, high=1, size=(30 * 20, 5)) * 2
self.qB_table = np.random.uniform(low=-1, high=1, size=(30 * 20, 5)) * 2
def bins(self, clip_min, clip_max, num):
"""分箱处理函数,把[clip_min,clip_max]区间平均分为num段,位于i段区间的特征值x会被离散化为i"""
return np.linspace(clip_min, clip_max, num + 1)[1:-1]
def digitize_state(self, observation):
"""get the discrete state in total 1296 states"""
cosTheta, sinTheta, thetaDot = observation
theta = math.acos(cosTheta)
if sinTheta < 0:
theta *= -1
# 分别对各个连续特征值进行离散化(分箱处理)
digitized = [np.digitize(theta, bins=self.bins(-math.pi, math.pi, 20)),
np.digitize(thetaDot, bins=self.bins(-8.0, 8.0, 30))]
return digitized[0] + 20 * digitized[1]
def select_action(self, observation, episode):
"""epsilon-greedy"""
state = self.digitize_state(observation)
epsilon = self.epsilon + (1 / (episode + 1))
# 使用两个Q表的均值来选择动作
if np.random.uniform(0, 1) < epsilon:
action = np.argmax((self.qA_table[state, :] + self.qB_table[state, :]) / 2) # 查表得到最佳行动
else:
action = np.random.randint(0, 4)
return action
def update(self, observation, action, reward, observation_next, done):
state = self.digitize_state(observation)
state_next = self.digitize_state(observation_next)
# randomly update either QA or QB
if np.random.rand() < 0.5: # updade QA
action_next_Q_values = self.qA_table[state_next, :]
if done:
target_Q = reward
else:
max_Q_action = np.random.choice(np.where(action_next_Q_values == action_next_Q_values.max())[0])
target_Q = reward + self.gamma * self.qB_table[state_next, max_Q_action]
self.qA_table[state, action] += self.lr * (target_Q - self.qA_table[state, action])
else: # updade QB
action_next_Q_values = self.qB_table[state_next, :]
if done:
target_Q = reward
else:
max_Q_action = np.random.choice(np.where(action_next_Q_values == action_next_Q_values.max())[0])
target_Q = reward + self.gamma * self.qA_table[state_next, max_Q_action]
self.qB_table[state, action] += self.lr * (target_Q - self.qB_table[state, action])
def run(self):
reward_ep = []
max_q_value_list = []
max_q_value = 0
for episode in range(self.num_episodes): # 1000 episodes
observation = self.env.reset() # initialize environment
total_reward = 0
for step in range(self.max_steps): # steps in one episode
action = self.select_action(observation, episode)
observation_next, reward, done, _ = self.env.step([action - 2])
self.update(observation, action, reward, observation_next, done)
observation = observation_next
total_reward += reward
if done:
reward_ep.append(total_reward)
print('{0} Episode: Total Reward: {1}'.format(episode, total_reward))
break
return reward_ep
if __name__ == '__main__':
dql = DoubleQLearning()
reward_ep = dql.run()
episodes_list = list(range(len(reward_ep)))
mv_return = moving_average(reward_ep, 5)
plt.figure()
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('Double Q-Learning on {}'.format("Pendulum-v0"))
plt.show()
代码运行结果如下:
\quad
\quad
参考:
- https://blog.csdn.net/gsww404/article/details/103413124
- https://www.jiqizhixin.com/graph/technologies/0d189dc7-7f80-4643-9ff4-74941694d7d4
- https://zhuanlan.zhihu.com/p/57445939
持续更新中…
边栏推荐
- AURIX Development Studio安装
- Remote sensing submission process
- DGC best practice: how to ensure that confidential data is not leaked when entering the lake?
- 02基于ZigBee的智能家居系统设计
- 01基于RFID的智能仓储管理系统设计
- Powerful chart component library scottplot
- 03基于ZigBee的城市道路除尘降温系统设计
- Second order edge detection Laplacian of Guassian Gaussian Laplacian operator
- Fisher线性判别分析Fisher Linear Distrimination
- 01 design of intelligent warehouse management system based on RFID
猜你喜欢
![[literature reading] mcunet: tiny deep learning on IOT devices](/img/67/21e5c6b7cf95073850be4c7c20520c.png)
[literature reading] mcunet: tiny deep learning on IOT devices

Prohibit smart Safari from playing automatically when opening a web page

gdb+vscode进行调试2——gdb断点相关

电解电容特性及应用要点

windows安装mysql和jdbc

02 design of smart home system based on ZigBee

Hands on deep learning -- linear neural network

S32K148EVB 关于ENET Loopback实验

SAE j1708/j1587 protocol details

YYDS!阿里技术官最新总结的分布式核心技术笔记已上线,堪称福音
随机推荐
Oozie 集成 Shell
HRNet
Saber's most powerful digital analog mixed signal simulation software
Set up sqoop environment
close 和 shutdown区别
工厂方法模式随记
中心极限定理
Aurix development studio installation
[cute new problem solving] sum of three numbers
Recursive and recursive learning notes
gdb+vscode进行调试6——gdb调试多线程命令札记
Fairness in Deep Learning: A Computational Perspective
Yolov5训练建议
工程编译那点事:Makefile和cmake(一)
The following packages have unmet dependencies: deepin. com. wechat:i386 : Depends: deepin-wine:i386
Hue oozie editor scheduling shell
Frustratingly Simple Few-Shot Object Detection
deep learning实验笔记
01 design of intelligent warehouse management system based on RFID
Opengauss Developer Day 2022 dongfangtong sincerely invites you to visit the "dongfangtong ecological tools sub forum"