# 20.3 Basic Principles and Implementation of GRU

### 20.3.1 Basic Concepts of GRU

LSTM has many variants, among which the Gated Recurrent Unit (GRU) is the most common and currently one of the most popular. GRU was proposed by Cho et al. in 2014 and simplifies LSTM in several ways:

1. GRU reduces LSTM's three gates to two: the reset gate $r_t$ (Reset Gate) and the update gate $z_t$ (Update Gate).
2. GRU does not keep a separate cell state $c_t$; only the hidden state $h_t$ is kept as the cell output, so its interface matches that of a vanilla RNN.
3. The reset gate acts directly on the previous hidden state $h_{t-1}$.
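The reduction from three gates to two can be made concrete by counting parameters. A minimal sketch, assuming the bias-free formulation used in this section; `D` and `H` are arbitrary example sizes:

```python
# Parameter count per cell, without biases (matching the formulas in this
# section). D = input size, H = hidden size; the values are illustrative.
D, H = 4, 3
per_group = H * H + D * H    # one recurrent matrix W plus one input matrix U
lstm_params = 4 * per_group  # LSTM: gates i, f, o plus the candidate cell state
gru_params = 3 * per_group   # GRU: gates z, r plus the candidate hidden state
```

With these example sizes GRU needs 63 parameters versus LSTM's 84, a 25% reduction that holds for any `D` and `H` under this layout.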

### 20.3.2 Forward Computation of GRU

#### Structure of the GRU Cell

The forward computation of a GRU cell is given by the following formulas, where $\sigma$ is the sigmoid function and $\circ$ denotes element-wise (Hadamard) multiplication:

1. Update gate

$$
z_t = \sigma(h_{t-1} \cdot W_z + x_t \cdot U_z) \tag{1}
$$

2. Reset gate

$$
r_t = \sigma(h_{t-1} \cdot W_r + x_t \cdot U_r) \tag{2}
$$

3. Candidate hidden state

$$
\tilde h_t = \tanh((r_t \circ h_{t-1}) \cdot W_h + x_t \cdot U_h) \tag{3}
$$

4. Hidden state

$$
h_t = (1 - z_t) \circ h_{t-1} + z_t \circ \tilde{h}_t \tag{4}
$$
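Equations (1)-(4) can be sketched directly in NumPy. This is a standalone illustration with random weights and no biases, using a row-vector convention; `gru_step` is a hypothetical helper, not the class implemented later in this section:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gru_step(x_t, h_prev, Wz, Uz, Wr, Ur, Wh, Uh):
    """One GRU forward step following equations (1)-(4).
    x_t has shape (1, input_size), h_prev has shape (1, hidden_size)."""
    z = sigmoid(h_prev @ Wz + x_t @ Uz)               # update gate, eq. (1)
    r = sigmoid(h_prev @ Wr + x_t @ Ur)               # reset gate, eq. (2)
    h_tilde = np.tanh((r * h_prev) @ Wh + x_t @ Uh)   # candidate state, eq. (3)
    h = (1 - z) * h_prev + z * h_tilde                # new hidden state, eq. (4)
    return h, z, r, h_tilde

rng = np.random.default_rng(0)
input_size, hidden_size = 4, 3
x_t = rng.standard_normal((1, input_size))
h_prev = rng.standard_normal((1, hidden_size))
params = [rng.standard_normal(s) for s in
          [(hidden_size, hidden_size), (input_size, hidden_size)] * 3]
h, z, r, h_tilde = gru_step(x_t, h_prev, *params)
```

Because $z_t \in (0,1)$, equation (4) makes $h_t$ an element-wise interpolation between the previous state and the candidate state: when $z_t \to 0$ the state is carried over unchanged, and when $z_t \to 1$ it is fully replaced.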

### 20.3.3 Backward Propagation of GRU

For the derivation, first define the pre-activations of the two gates and the candidate state:

$$
z_{zt} = h_{t-1} \cdot W_z + x_t \cdot U_z \tag{5}
$$

$$
z_{rt} = h_{t-1} \cdot W_r + x_t \cdot U_r \tag{6}
$$

$$
z_{\tilde h_t} = (r_t \circ h_{t-1}) \cdot W_h + x_t \cdot U_h \tag{7}
$$

Let $\delta_t = \dfrac{\partial{loss}}{\partial{h_t}}$ be the error received at time step $t$. The error terms at the three pre-activations are:

$$
\begin{aligned} \delta_{z_{zt}} &= \frac{\partial{loss}}{\partial{h_t}} \cdot \frac{\partial{h_t}}{\partial{z_t}} \cdot \frac{\partial{z_t}}{\partial{z_{z_t}}} \\\\ &= \delta_t \cdot (-diag[h_{t-1}] + diag[\tilde h_t]) \cdot diag[z_t \circ (1-z_t)] \\\\ &= \delta_t \circ (\tilde h_t - h_{t-1}) \circ z_t \circ (1-z_t) \end{aligned} \tag{8}
$$

$$
\begin{aligned} \delta_{z_{\tilde{h}t}} &= \frac{\partial{loss}}{\partial{h_t}} \cdot \frac{\partial{h_t}}{\partial{\tilde h_t}} \cdot \frac{\partial{\tilde h_t}}{\partial{z_{\tilde h_t}}} \\\\ &= \delta_t \cdot diag[z_t] \cdot diag[1-(\tilde h_t)^2] \\\\ &= \delta_t \circ z_t \circ (1-(\tilde h_t)^2) \end{aligned} \tag{9}
$$

$$
\begin{aligned} \delta_{z_{rt}} &= \frac{\partial{loss}}{\partial{\tilde h_t}} \cdot \frac{\partial{\tilde h_t}}{\partial{z_{\tilde h_t}}} \cdot \frac{\partial{z_{\tilde h_t}}}{\partial{r_t}} \cdot \frac{\partial{r_t}}{\partial{z_{r_t}}} \\\\ &= \delta_{z_{\tilde{h}t}} \cdot W_h^T \cdot diag[h_{t-1}] \cdot diag[r_t \circ (1-r_t)] \\\\ &= \delta_{z_{\tilde{h}t}} \cdot W_h^T \circ h_{t-1} \circ r_t \circ (1-r_t) \end{aligned} \tag{10}
$$

By the chain rule, the gradients of the weights at a single time step $t$ are:

$$
d_{W_{h,t}} = \frac{\partial{loss}}{\partial{z_{\tilde h_t}}} \cdot \frac{\partial{z_{\tilde h_t}}}{\partial{W_h}} = (r_t \circ h_{t-1})^{\top} \cdot \delta_{z_{\tilde{h}t}} \tag{11}
$$

$$
d_{U_{h,t}} = \frac{\partial{loss}}{\partial{z_{\tilde h_t}}} \cdot \frac{\partial{z_{\tilde h_t}}}{\partial{U_h}} = x_t^{\top} \cdot \delta_{z_{\tilde{h}t}} \tag{12}
$$

$$
d_{W_{r,t}} = \frac{\partial{loss}}{\partial{z_{r_t}}} \cdot \frac{\partial{z_{r_t}}}{\partial{W_r}} = h_{t-1}^{\top} \cdot \delta_{z_{rt}} \tag{13}
$$

$$
d_{U_{r,t}} = \frac{\partial{loss}}{\partial{z_{r_t}}} \cdot \frac{\partial{z_{r_t}}}{\partial{U_r}} = x_t^{\top} \cdot \delta_{z_{rt}} \tag{14}
$$

$$
d_{W_{z,t}} = \frac{\partial{loss}}{\partial{z_{z_t}}} \cdot \frac{\partial{z_{z_t}}}{\partial{W_z}} = h_{t-1}^{\top} \cdot \delta_{z_{zt}} \tag{15}
$$

$$
d_{U_{z,t}} = \frac{\partial{loss}}{\partial{z_{z_t}}} \cdot \frac{\partial{z_{z_t}}}{\partial{U_z}} = x_t^{\top} \cdot \delta_{z_{zt}} \tag{16}
$$

Since the weights are shared across all time steps, the total gradients sum over the sequence length $\tau$:

$$
d_{W_h} = \sum_{t=1}^{\tau} d_{W_{h,t}} = \sum_{t=1}^{\tau} (r_t \circ h_{t-1})^{\top} \cdot \delta_{z_{\tilde{h}t}} \tag{17}
$$

$$
d_{U_h} = \sum_{t=1}^{\tau} d_{U_{h,t}} = \sum_{t=1}^{\tau} x_t^{\top} \cdot \delta_{z_{\tilde{h}t}} \tag{18}
$$

$$
d_{W_r} = \sum_{t=1}^{\tau} d_{W_{r,t}} = \sum_{t=1}^{\tau} h_{t-1}^{\top} \cdot \delta_{z_{rt}} \tag{19}
$$

$$
d_{U_r} = \sum_{t=1}^{\tau} d_{U_{r,t}} = \sum_{t=1}^{\tau} x_t^{\top} \cdot \delta_{z_{rt}} \tag{20}
$$

$$
d_{W_z} = \sum_{t=1}^{\tau} d_{W_{z,t}} = \sum_{t=1}^{\tau} h_{t-1}^{\top} \cdot \delta_{z_{zt}} \tag{21}
$$

$$
d_{U_z} = \sum_{t=1}^{\tau} d_{U_{z,t}} = \sum_{t=1}^{\tau} x_t^{\top} \cdot \delta_{z_{zt}} \tag{22}
$$

The error passed back to the previous time step gathers contributions from all four paths through $h_{t-1}$:

$$
\begin{aligned} \delta_{h_{t-1}} = \frac{\partial{loss}}{\partial{h_{t-1}}} &= \frac{\partial{loss}}{\partial{h_t}} \cdot \frac{\partial{h_t}}{\partial{h_{t-1}}} + \frac{\partial{loss}}{\partial{z_{\tilde h_t}}} \cdot \frac{\partial{z_{\tilde h_t}}}{\partial{h_{t-1}}} \\\\ &+ \frac{\partial{loss}}{\partial{z_{rt}}} \cdot \frac{\partial{z_{rt}}}{\partial{h_{t-1}}} + \frac{\partial{loss}}{\partial{z_{zt}}} \cdot \frac{\partial{z_{zt}}}{\partial{h_{t-1}}} \\\\ &= \delta_{t} \circ (1-z_t) + \delta_{z_{\tilde{h}t}} \cdot W_h^{\top} \circ r_t \\\\ &+ \delta_{z_{rt}} \cdot W_r^{\top} + \delta_{z_{zt}} \cdot W_z^{\top} \end{aligned} \tag{23}
$$

The error passed down to the previous layer is:

$$
\begin{aligned} \delta_{x_t} &= \frac{\partial{loss}}{\partial{x_t}} = \frac{\partial{loss}}{\partial{z_{\tilde h_t}}} \cdot \frac{\partial{z_{\tilde h_t}}}{\partial{x_t}} \\\\ &+ \frac{\partial{loss}}{\partial{z_{r_t}}} \cdot \frac{\partial{z_{r_t}}}{\partial{x_t}} + \frac{\partial{loss}}{\partial{z_{z_t}}} \cdot \frac{\partial{z_{z_t}}}{\partial{x_t}} \\\\ &= \delta_{z_{\tilde{h}t}} \cdot U_h^{\top} + \delta_{z_{rt}} \cdot U_r^{\top} + \delta_{z_{zt}} \cdot U_z^{\top} \end{aligned} \tag{24}
$$
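Derivations like (8)-(10) and (24) are easy to get wrong, so they are worth spot-checking against finite differences. A minimal sketch, assuming random weights, no biases, and a toy loss $loss = \sum h_t$ (so $\delta_t$ is a vector of ones); all names here are illustrative:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def forward(x, h_prev, p):
    # one GRU step, equations (1)-(4)
    Wz, Uz, Wr, Ur, Wh, Uh = p
    z = sigmoid(h_prev @ Wz + x @ Uz)
    r = sigmoid(h_prev @ Wr + x @ Ur)
    n = np.tanh((r * h_prev) @ Wh + x @ Uh)
    h = (1 - z) * h_prev + z * n
    return h, z, r, n

rng = np.random.default_rng(1)
D, H = 4, 3
x = rng.standard_normal((1, D))
h_prev = rng.standard_normal((1, H))
p = [rng.standard_normal(s) for s in [(H, H), (D, H)] * 3]
Wz, Uz, Wr, Ur, Wh, Uh = p

# analytic gradient of loss = h.sum() w.r.t. x, following eqs. (8)-(10), (24)
h, z, r, n = forward(x, h_prev, p)
delta_t = np.ones_like(h)                    # d loss / d h_t
dzz = delta_t * (n - h_prev) * z * (1 - z)   # eq. (8)
dzn = delta_t * z * (1 - n * n)              # eq. (9)
dzr = (dzn @ Wh.T) * h_prev * r * (1 - r)    # eq. (10)
dx = dzn @ Uh.T + dzr @ Ur.T + dzz @ Uz.T    # eq. (24)

# numerical gradient by central differences
eps = 1e-6
dx_num = np.zeros_like(x)
for i in range(D):
    xp = x.copy(); xp[0, i] += eps
    xm = x.copy(); xm[0, i] -= eps
    dx_num[0, i] = (forward(xp, h_prev, p)[0].sum()
                    - forward(xm, h_prev, p)[0].sum()) / (2 * eps)
```

The analytic `dx` and the numerical `dx_num` should agree to several decimal places; the same trick applies to the weight gradients (11)-(16).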

### 20.3.4 Code Implementation

#### Initialization

```python
def __init__(self, input_size, hidden_size):
    self.input_size = input_size
    self.hidden_size = hidden_size
```


#### Forward Computation

```python
def forward(self, x, h_p, W, U):
    self.get_params(W, U)
    self.x = x
    # equations (1)-(4): update gate, reset gate, candidate state, hidden state
    self.z = Sigmoid().forward(np.dot(h_p, self.wz) + np.dot(x, self.uz))
    self.r = Sigmoid().forward(np.dot(h_p, self.wr) + np.dot(x, self.ur))
    self.n = Tanh().forward(np.dot((self.r * h_p), self.wn) + np.dot(x, self.un))
    self.h = (1 - self.z) * h_p + self.z * self.n

def split_params(self, w, size):
    s = []
    for i in range(3):
        s.append(w[(i * size):((i + 1) * size)])
    return s[0], s[1], s[2]

# Get shared parameters, and split them to fit 3 gates, in the order of
# z, r, \tilde{h} (n stands for \tilde{h} in code)
def get_params(self, W, U):
    self.wz, self.wr, self.wn = self.split_params(W, self.hidden_size)
    self.uz, self.ur, self.un = self.split_params(U, self.input_size)
```
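The parameter layout that `split_params` assumes can be illustrated with a standalone sketch (sizes are arbitrary example values): `W` stacks the three hidden-to-hidden matrices row-wise in the order z, r, n, giving shape `(3 * hidden_size, hidden_size)`, and `U` likewise has shape `(3 * input_size, hidden_size)`:

```python
import numpy as np

hidden_size, input_size = 3, 4
# shared parameter blocks, stacked row-wise in the order z, r, n
W = np.arange(3 * hidden_size * hidden_size).reshape(3 * hidden_size, hidden_size)
U = np.arange(3 * input_size * hidden_size).reshape(3 * input_size, hidden_size)

def split_params(w, size):
    # same slicing as the class method: three row blocks of `size` rows each
    return w[0:size], w[size:2 * size], w[2 * size:3 * size]

wz, wr, wn = split_params(W, hidden_size)
uz, ur, un = split_params(U, input_size)
```

Stacking the three groups into one array lets the surrounding network treat the GRU's parameters as a single `W` and a single `U`, just as in the LSTM implementation.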


#### Backward Propagation

```python
def backward(self, h_p, in_grad):
    # error terms at the three pre-activations, equations (8)-(10)
    self.dzz = in_grad * (self.n - h_p) * self.z * (1 - self.z)
    self.dzn = in_grad * self.z * (1 - self.n * self.n)
    self.dzr = np.dot(self.dzn, self.wn.T) * h_p * self.r * (1 - self.r)
    # per-time-step weight gradients, equations (11)-(16)
    self.dwn = np.dot((self.r * h_p).T, self.dzn)
    self.dun = np.dot(self.x.T, self.dzn)
    self.dwr = np.dot(h_p.T, self.dzr)
    self.dur = np.dot(self.x.T, self.dzr)
    self.dwz = np.dot(h_p.T, self.dzz)
    self.duz = np.dot(self.x.T, self.dzz)

    self.merge_params()

    # pass to previous time step, equation (23)
    self.dh = in_grad * (1 - self.z) + np.dot(self.dzn, self.wn.T) * self.r \
        + np.dot(self.dzr, self.wr.T) + np.dot(self.dzz, self.wz.T)
    # pass to previous layer, equation (24)
    self.dx = np.dot(self.dzn, self.un.T) + np.dot(self.dzr, self.ur.T) \
        + np.dot(self.dzz, self.uz.T)

def merge_params(self):
    # stack the per-gate gradients back into the shared-parameter layout
    self.dW = np.concatenate((self.dwz, self.dwr, self.dwn), axis=0)
    self.dU = np.concatenate((self.duz, self.dur, self.dun), axis=0)
```


### 20.3.5 Final Results

Three test samples after training:

```
  x1: [1, 1, 1, 0]
- x2: [1, 0, 0, 0]
------------------
true: [0, 1, 1, 0]
pred: [0, 1, 1, 0]
14 - 8 = 6
====================
  x1: [1, 1, 0, 0]
- x2: [0, 0, 0, 0]
------------------
true: [1, 1, 0, 0]
pred: [1, 1, 0, 0]
12 - 0 = 12
====================
  x1: [1, 0, 1, 0]
- x2: [0, 0, 0, 1]
------------------
true: [1, 0, 0, 1]
pred: [1, 0, 0, 1]
10 - 1 = 9
```

Code location: ch20, Level2
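The decimal line under each sample can be reproduced by reading the 4-bit vectors as binary numbers, most-significant bit first; a small sketch (`bits_to_int` is an illustrative helper, not part of the course code):

```python
def bits_to_int(bits):
    # most-significant-bit first: [1, 1, 1, 0] -> 0b1110 -> 14
    value = 0
    for b in bits:
        value = value * 2 + b
    return value

# first sample above: x1 - x2 = true
print(bits_to_int([1, 1, 1, 0]),
      bits_to_int([1, 0, 0, 0]),
      bits_to_int([0, 1, 1, 0]))  # 14 8 6
```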

### Questions and Exercises

1. Following the LSTM implementation, implement the forward computation and backward propagation of a GRU cell yourself.
2. Following the LSTM example, train a model with GRU and compare the results.