20.3 GRU: Basic Principles and Implementation
20.3.1 Basic Concepts of GRU
LSTM has many variants, and the Gated Recurrent Unit (GRU) is the most common and currently one of the most popular. GRU was proposed by Cho et al. in 2014 and simplifies the LSTM in several ways:
GRU merges LSTM's three gates into two: the reset gate r_t and the update gate z_t.
GRU does not keep a separate cell state c_t; it keeps only the hidden state h_t as the cell output, which makes its structure consistent with that of a traditional RNN.
The reset gate acts directly on the previous hidden state h_{t-1}.
20.3.2 Forward Computation of GRU
Structure of the GRU Cell
Figure 20-7 shows the structure of a GRU cell.
Figure 20-7 Structure of a GRU cell
The forward computation of a GRU cell is as follows:
Update gate
z_t = \sigma(h_{t-1} \cdot W_z + x_t \cdot U_z)
\tag{1}
Reset gate
r_t = \sigma(h_{t-1} \cdot W_r + x_t \cdot U_r)
\tag{2}
Candidate hidden state
\tilde h_t = \tanh((r_t \circ h_{t-1}) \cdot W_h + x_t \cdot U_h)
\tag{3}
Hidden state
h_t = (1 - z_t) \circ h_{t-1} + z_t \circ \tilde{h}_t
\tag{4}
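Equations (1)-(4) map directly onto a few lines of numpy. Below is a minimal sketch of one forward step, assuming a row-vector input x of shape (1, input_size) and hidden state h_prev of shape (1, hidden_size); the names here are illustrative and distinct from the code listed later in this section.

import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_forward_step(x, h_prev, Wz, Uz, Wr, Ur, Wh, Uh):
    z = sigmoid(np.dot(h_prev, Wz) + np.dot(x, Uz))             # eq. (1): update gate
    r = sigmoid(np.dot(h_prev, Wr) + np.dot(x, Ur))             # eq. (2): reset gate
    h_tilde = np.tanh(np.dot(r * h_prev, Wh) + np.dot(x, Uh))   # eq. (3): candidate state
    return (1 - z) * h_prev + z * h_tilde                       # eq. (4): new hidden state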
A Brief Analysis of How GRU Works
The formulas above show that GRU uses the update gate and the reset gate to control how much of the long-term state is forgotten or retained, and how much of the current input is selected. Both gates pass their inputs through the sigmoid function, mapping them into the interval [0,1] to realize the gating effect.
First, the previous hidden state h_{t-1} passes through the reset gate and is combined with the current input to form the candidate state \tilde{h}_t of the current time step, which the \tanh function maps into the interval [-1,1].
Then the update gate handles both forgetting and memorizing. As the hidden-state formula (4) shows, z_t performs selective forgetting and memorizing. Because (1-z_t) and z_t are coupled, the more of the previous state that is forgotten, the more of the current information is remembered; a single gate thus plays the roles of both f_t and i_t in LSTM.
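To make the coupling concrete, here is a hypothetical one-dimensional example: the larger z_t is, the more h_t moves toward the candidate \tilde{h}_t and the less of h_{t-1} survives.

import numpy as np

h_prev, h_tilde = np.array([0.8]), np.array([-0.5])
for z in (0.1, 0.5, 0.9):
    h = (1 - z) * h_prev + z * h_tilde   # eq. (4)
    print(f"z_t = {z}: h_t = {h[0]:+.2f}")
# prints +0.67, +0.15, -0.37: a larger z_t pulls h_t toward h_tilde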
20.3.3 Backpropagation of GRU
Having worked through the backpropagation derivation for LSTM, the derivation for GRU is relatively simple. As before, we take the GRU cell at layer l and time step t as an example and derive the backpropagation process.
As with LSTM, let the error flowing into layer l at time t be \delta_{t}^l, the sum of the error \delta_{h_t}^l passed back from the next time step and the error \delta_{x_t}^{l+1} passed down from layer l+1; we abbreviate it as \delta_{t}.
Let:
z_{zt} = h_{t-1} \cdot W_z + x_t \cdot U_z
\tag{5}
z_{rt} = h_{t-1} \cdot W_r + x_t \cdot U_r
\tag{6}
z_{\tilde h_t} = (r_t \circ h_{t-1}) \cdot W_h + x_t \cdot U_h
\tag{7}
Then:
\begin{aligned}
\delta_{z_{zt}} &= \frac{\partial{loss}}{\partial{h_t}} \cdot \frac{\partial{h_t}}{\partial{z_t}} \cdot \frac{\partial{z_t}}{\partial{z_{z_t}}} \\\\
&= \delta_t \cdot (-diag[h_{t-1}] + diag[\tilde h_t]) \cdot diag[z_t \circ (1-z_t)] \\\\
&= \delta_t \circ (\tilde h_t - h_{t-1}) \circ z_t \circ (1-z_t)
\end{aligned}
\tag{8}
\begin{aligned}
\delta_{z_{\tilde{h}t}} &= \frac{\partial{loss}}{\partial{h_t}} \cdot \frac{\partial{h_t}}{\partial{\tilde h_t}} \cdot \frac{\partial{\tilde h_t}}{\partial{z_{\tilde h_t}}} \\\\
&= \delta_t \cdot diag[z_t] \cdot diag[1-(\tilde h_t)^2] \\\\
&= \delta_t \circ z_t \circ (1-(\tilde h_t)^2)
\end{aligned}
\tag{9}
\begin{aligned}
\delta_{z_{rt}} &= \frac{\partial{loss}}{\partial{\tilde h_t}} \cdot \frac{\partial{\tilde h_t}}{\partial{z_{\tilde h_t}}} \cdot \frac{\partial{z_{\tilde h_t}}}{\partial{r_t}} \cdot \frac{\partial{r_t}}{\partial{z_{r_t}}} \\\\
&= \delta_{z_{\tilde{h}t}} \cdot W_h^T \cdot diag[h_{t-1}] \cdot diag[r_t \circ (1-r_t)] \\\\
&= \delta_{z_{\tilde{h}t}} \cdot W_h^T \circ h_{t-1} \circ r_t \circ (1-r_t)
\end{aligned}
\tag{10}
From these, the gradients of the learnable parameters at time step t can be obtained:
\begin{aligned}
d_{W_{h,t}} = \frac{\partial{loss}}{\partial{z_{\tilde h_t}}} \cdot \frac{\partial{z_{\tilde h_t}}}{\partial{W_h}} = (r_t \circ h_{t-1})^{\top} \cdot \delta_{z_{\tilde{h}t}}
\end{aligned}
\tag{11}
\begin{aligned}
d_{U_{h,t}} = \frac{\partial{loss}}{\partial{z_{\tilde h_t}}} \cdot \frac{\partial{z_{\tilde h_t}}}{\partial{U_h}} = x_t^{\top} \cdot \delta_{z_{\tilde{h}t}}
\end{aligned}
\tag{12}
\begin{aligned}
d_{W_{r,t}} = \frac{\partial{loss}}{\partial{z_{r_t}}} \cdot \frac{\partial{z_{r_t}}}{\partial{W_r}} = h_{t-1}^{\top} \cdot \delta_{z_{rt}}
\end{aligned}
\tag{13}
\begin{aligned}
d_{U_{r,t}} = \frac{\partial{loss}}{\partial{z_{r_t}}} \cdot \frac{\partial{z_{r_t}}}{\partial{U_r}} = x_t^{\top} \cdot \delta_{z_{rt}}
\end{aligned}
\tag{14}
\begin{aligned}
d_{W_{z,t}} = \frac{\partial{loss}}{\partial{z_{z_t}}} \cdot \frac{\partial{z_{z_t}}}{\partial{W_z}} = h_{t-1}^{\top} \cdot \delta_{z_{zt}}
\end{aligned}
\tag{15}
\begin{aligned}
d_{U_{z,t}} = \frac{\partial{loss}}{\partial{z_{z_t}}} \cdot \frac{\partial{z_{z_t}}}{\partial{U_z}} = x_t^{\top} \cdot \delta_{z_{zt}}
\end{aligned}
\tag{16}
The final gradient of each learnable parameter is the sum of its gradients over all time steps:
d_{W_h} = \sum_{t=1}^{\tau} d_{W_{h,t}} = \sum_{t=1}^{\tau} (r_t \circ h_{t-1})^{\top} \cdot \delta_{z_{\tilde{h}t}}
\tag{17}
d_{U_h} = \sum_{t=1}^{\tau} d_{U_{h,t}} = \sum_{t=1}^{\tau} x_t^{\top} \cdot \delta_{z_{\tilde{h}t}}
\tag{18}
d_{W_r} = \sum_{t=1}^{\tau} d_{W_{r,t}} = \sum_{t=1}^{\tau} h_{t-1}^{\top} \cdot \delta_{z_{rt}}
\tag{19}
d_{U_r} = \sum_{t=1}^{\tau} d_{U_{r,t}} = \sum_{t=1}^{\tau} x_t^{\top} \cdot \delta_{z_{rt}}
\tag{20}
d_{W_z} = \sum_{t=1}^{\tau} d_{W_{z,t}} = \sum_{t=1}^{\tau} h_{t-1}^{\top} \cdot \delta_{z_{zt}}
\tag{21}
d_{U_z} = \sum_{t=1}^{\tau} d_{U_{z,t}} = \sum_{t=1}^{\tau} x_t^{\top} \cdot \delta_{z_{zt}}
\tag{22}
The current GRU cell passes the error back to the previous time step (t-1) and down to the layer below (l-1) with the following formulas:
Propagating backward through time:
\begin{aligned}
\delta_{h_{t-1}} = \frac{\partial{loss}}{\partial{h_{t-1}}} &= \frac{\partial{loss}}{\partial{h_t}} \cdot \frac{\partial{h_t}}{\partial{h_{t-1}}} + \frac{\partial{loss}}{\partial{z_{\tilde h_t}}} \cdot \frac{\partial{z_{\tilde h_t}}}{\partial{h_{t-1}}} \\\\
&+ \frac{\partial{loss}}{\partial{z_{rt}}} \cdot \frac{\partial{z_{rt}}}{\partial{h_{t-1}}} + \frac{\partial{loss}}{\partial{z_{zt}}} \cdot \frac{\partial{z_{zt}}}{\partial{h_{t-1}}} \\\\
&= \delta_{t} \circ (1-z_t) + \delta_{z_{\tilde{h}t}} \cdot W_h^{\top} \circ r_t \\\\
&+ \delta_{z_{rt}} \cdot W_r^{\top} + \delta_{z_{zt}} \cdot W_z^{\top}
\end{aligned}
\tag{23}
Propagating down to the layer below:
\begin{aligned}
\delta_{x_t} &= \frac{\partial{loss}}{\partial{x_t}} = \frac{\partial{loss}}{\partial{z_{\tilde h_t}}} \cdot \frac{\partial{z_{\tilde h_t}}}{\partial{x_t}} \\\\
&+ \frac{\partial{loss}}{\partial{z_{r_t}}} \cdot \frac{\partial{z_{r_t}}}{\partial{x_t}} + \frac{\partial{loss}}{\partial{z_{z_t}}} \cdot \frac{\partial{z_{z_t}}}{\partial{x_t}} \\\\
&= \delta_{z_{\tilde{h}t}} \cdot U_h^{\top} + \delta_{z_{rt}} \cdot U_r^{\top} + \delta_{z_{zt}} \cdot U_z^{\top}
\end{aligned}
\tag{24}
This completes the derivation of the GRU backpropagation formulas.
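As a sanity check on equations (8)-(10) and (24), the analytic gradient \delta_{x_t} can be compared against central finite differences. The sketch below re-implements the forward step inline and uses the scalar loss sum(h_t), so \delta_t is a vector of ones; all names and shapes here are illustrative assumptions.

import numpy as np

rng = np.random.default_rng(0)
D, H = 3, 4
x = rng.standard_normal((1, D))
h_prev = rng.standard_normal((1, H))
P = {k: 0.1 * rng.standard_normal(s) for k, s in
     [("Wz", (H, H)), ("Uz", (D, H)), ("Wr", (H, H)),
      ("Ur", (D, H)), ("Wh", (H, H)), ("Uh", (D, H))]}

def step(x):
    z = 1 / (1 + np.exp(-(h_prev @ P["Wz"] + x @ P["Uz"])))  # eq. (1)
    r = 1 / (1 + np.exp(-(h_prev @ P["Wr"] + x @ P["Ur"])))  # eq. (2)
    n = np.tanh((r * h_prev) @ P["Wh"] + x @ P["Uh"])        # eq. (3)
    return z, r, n, (1 - z) * h_prev + z * n                 # eq. (4)

z, r, n, h = step(x)
delta = np.ones_like(h)                          # loss = h.sum() => dloss/dh = 1
dzz = delta * (n - h_prev) * z * (1 - z)         # eq. (8)
dzn = delta * z * (1 - n * n)                    # eq. (9)
dzr = (dzn @ P["Wh"].T) * h_prev * r * (1 - r)   # eq. (10)
dx = dzn @ P["Uh"].T + dzr @ P["Ur"].T + dzz @ P["Uz"].T  # eq. (24)

# central finite differences on each component of x
eps, num = 1e-6, np.zeros_like(x)
for i in range(D):
    xp, xm = x.copy(), x.copy()
    xp[0, i] += eps
    xm[0, i] -= eps
    num[0, i] = (step(xp)[3].sum() - step(xm)[3].sum()) / (2 * eps)
print(np.allclose(dx, num, atol=1e-5))           # True if the derivation is consistent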
20.3.4 Code Implementation
This section implements the forward computation and backpropagation of a GRU cell. For simplicity and consistency, the test case is again binary subtraction.
Initialization
This example implements a GRU cell without bias, so only the input size and hidden size need to be initialized.
def __init__(self, input_size, hidden_size):
    self.input_size = input_size
    self.hidden_size = hidden_size
Forward computation
def forward(self, x, h_p, W, U):
    self.get_params(W, U)
    self.x = x
    self.z = Sigmoid().forward(np.dot(h_p, self.wz) + np.dot(x, self.uz))  # eq. (1)
    self.r = Sigmoid().forward(np.dot(h_p, self.wr) + np.dot(x, self.ur))  # eq. (2)
    self.n = Tanh().forward(np.dot((self.r * h_p), self.wn) + np.dot(x, self.un))  # eq. (3)
    self.h = (1 - self.z) * h_p + self.z * self.n  # eq. (4)
def split_params(self, w, size):
    s = []
    for i in range(3):
        s.append(w[(i * size):((i + 1) * size)])
    return s[0], s[1], s[2]

# Get shared parameters, and split them to fit 3 gates,
# in the order of z, r, \tilde{h} (n stands for \tilde{h} in code)
def get_params(self, W, U):
    self.wz, self.wr, self.wn = self.split_params(W, self.hidden_size)
    self.uz, self.ur, self.un = self.split_params(U, self.input_size)
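For reference, here is a hypothetical instantiation showing the shapes these helpers expect: W stacks wz, wr, wn along axis 0, each of shape (hidden_size, hidden_size), and U stacks uz, ur, un, each of shape (input_size, hidden_size). The class name GRUCell is an assumption; the listings in this section do not show it.

import numpy as np

input_size, hidden_size = 2, 4
W = np.random.randn(3 * hidden_size, hidden_size) * 0.1  # wz, wr, wn stacked
U = np.random.randn(3 * input_size, hidden_size) * 0.1   # uz, ur, un stacked

cell = GRUCell(input_size, hidden_size)  # assumed class holding the methods above
x = np.random.randn(1, input_size)
h_p = np.zeros((1, hidden_size))
cell.forward(x, h_p, W, U)
print(cell.h.shape)  # (1, hidden_size)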
Backpropagation
def backward(self, h_p, in_grad):
    # eqs. (8)-(10): gradients w.r.t. the gate pre-activations
    self.dzz = in_grad * (self.n - h_p) * self.z * (1 - self.z)
    self.dzn = in_grad * self.z * (1 - self.n * self.n)
    self.dzr = np.dot(self.dzn, self.wn.T) * h_p * self.r * (1 - self.r)
    # eqs. (11)-(16): gradients of the learnable parameters at time t
    self.dwn = np.dot((self.r * h_p).T, self.dzn)
    self.dun = np.dot(self.x.T, self.dzn)
    self.dwr = np.dot(h_p.T, self.dzr)
    self.dur = np.dot(self.x.T, self.dzr)
    self.dwz = np.dot(h_p.T, self.dzz)
    self.duz = np.dot(self.x.T, self.dzz)
    self.merge_params()
    # eq. (23): pass to previous time step
    self.dh = in_grad * (1 - self.z) + np.dot(self.dzn, self.wn.T) * self.r \
        + np.dot(self.dzr, self.wr.T) + np.dot(self.dzz, self.wz.T)
    # eq. (24): pass to previous layer
    self.dx = np.dot(self.dzn, self.un.T) + np.dot(self.dzr, self.ur.T) + np.dot(self.dzz, self.uz.T)
We merge the split parameter gradients back together, which makes the gradient update convenient.
def merge_params(self):
    self.dW = np.concatenate((self.dwz, self.dwr, self.dwn), axis=0)
    self.dU = np.concatenate((self.duz, self.dur, self.dun), axis=0)
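Putting the pieces together, the sketch below unrolls the cell over tau time steps and accumulates dW and dU across steps as in equations (17)-(22). It is not self-contained: it builds on the GRUCell sketch above and assumes the loss feeds in only at the last step; xs (the list of tau inputs), eta (the learning rate), and dloss_dh_last are hypothetical names, not part of the book's training code.

# forward through time, threading the hidden state
cells = [GRUCell(input_size, hidden_size) for _ in range(tau)]
h = np.zeros((1, hidden_size))
for cell, x in zip(cells, xs):
    cell.forward(x, h, W, U)
    h = cell.h

# backward through time, summing the parameter gradients (eqs. (17)-(22))
dW, dU = np.zeros_like(W), np.zeros_like(U)
dh = dloss_dh_last
for t in reversed(range(tau)):
    h_p = cells[t - 1].h if t > 0 else np.zeros((1, hidden_size))
    cells[t].backward(h_p, dh)
    dW += cells[t].dW
    dU += cells[t].dU
    dh = cells[t].dh        # eq. (23): error passed to time step t-1

W -= eta * dW               # plain SGD update
U -= eta * dU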
20.3.5 Final Results
Figure 20-8 shows the training process, with the loss and accuracy curves.
Figure 20-8 Loss and accuracy curves
The model reaches 100% accuracy on the validation set, which verifies the correctness of the network.
x1: [1, 1, 1, 0]
- x2: [1, 0, 0, 0]
------------------
true: [0, 1, 1, 0]
pred: [0, 1, 1, 0]
14 - 8 = 6
====================
x1: [1, 1, 0, 0]
- x2: [0, 0, 0, 0]
------------------
true: [1, 1, 0, 0]
pred: [1, 1, 0, 0]
12 - 0 = 12
====================
x1: [1, 0, 1, 0]
- x2: [0, 0, 0, 1]
------------------
true: [1, 0, 0, 1]
pred: [1, 0, 0, 1]
10 - 1 = 9
Code Location
ch20, Level2
Exercises
Following the LSTM implementation, implement the forward computation and backpropagation of a GRU cell yourself.
Following the LSTM example, train a network with GRU and compare the results.