跳转至

20.1 LSTM基本原理

20.1 LSTM基本原理⚓︎

本小节中,我们将学习长短时记忆(Long Short Term Memory, LSTM)网络的基本原理。

20.1.1 提出问题⚓︎

循环神经网络(RNN)的提出,使神经网络可以训练和解决带有时序信息的任务,大大拓宽了神经网络的使用范围。但是原始的RNN有明显的缺陷。不管是双向RNN,还是深度RNN,都有一个严重的缺陷:训练过程中经常会出现梯度爆炸和梯度消失的问题,以至于原始的RNN很难处理长距离的依赖。

从实例角度⚓︎

例如,在语言生成问题中:

佳佳今天帮助妈妈洗碗,帮助爸爸修理椅子,还帮助爷爷奶奶照顾小狗毛毛,大家都夸奖了

例句中出现了很多人,空白出要填谁呢?我们知道是“佳佳”,传统RNN无法很好学习这么远距离的依赖关系。

从理论角度⚓︎

根据循环神经网络的反向传播算法,可以得到任意时刻k, 误差项沿时间反向传播的公式如下: $$ \deltaT_k=\deltaT_t \prod_{i=k}^{t-1} diag[f'(z_i)]W $$

其中 f为激活函数,zi为神经网络在第i时刻的加权输入, W为权重矩阵,diag表示一个对角矩阵。

注意,由于使用链式求导法则,式中有一个连乘项 i=kt1diag[f(zi)]W , 如果激活函数是挤压型,例如 Tanhsigmoid , 他们的导数值在 [0,1] 之间。我们再来看W。 1. 如果W的值在 (0,1) 的范围内, 则随着t的增大,连乘项会越来越趋近于0, 误差无法传播,这就导致了 梯度消失 的问题。 2. 如果W的值很大,使得diag[f(zi)]W的值大于1, 则随着t的增大,连乘项的值会呈指数增长,并趋向于无穷,产生 梯度爆炸

梯度消失使得误差无法传递到较早的时刻,权重无法更新,网络停止学习。梯度爆炸又会使网络不稳定,梯度过大,权重变化太大,无法很好学习,最坏情况还会产生溢出(NaN)错误而无法更新权重。

解决办法⚓︎

为了解决这个问题,科学家们想了很多办法。

  1. 采用半线性激活函数ReLU代替 挤压型激活函数,ReLU函数在定义域大于0的部分,导数恒等于1,来解决梯度消失问题。
  2. 合理初始化权重W,使diag[f(zi)]W的值尽量趋近于1,避免梯度消失和梯度爆炸。

上面两种办法都有一定的缺陷,ReLU函数有自身的缺点,而初始化权重的策略也抵不过连乘操作带来的指数增长问题。要想根本解决问题,必须去掉连乘项。

科学家们冥思苦想,终于提出了新的模型 —— 长短时记忆网络(LSTM)。

20.1.2 LSTM网络⚓︎

20.1.2.1 LSTM的结构⚓︎

LSTM 的设计思路比较简单,原来的RNN中隐藏层只有一个状态h,对短期输入敏感,现在再增加一个状态c,来保存长期状态。这个新增状态称为 细胞状态(cell state)**或**单元状态。 增加细胞状态前后的网络对比,如图20-1,20-2所示。

图20-1 传统RNN结构示意图

图20-2 LSTM结构示意图

那么,如何控制长期状态c呢?在任意时刻t,我们需要确定三件事:

  1. t1时刻传入的状态ct1,有多少需要保留。
  2. 当前时刻的输入信息,有多少需要传递到t+1时刻。
  3. 当前时刻的隐层输出ht是什么。

LSTM设计了 门控(gate) 结构,控制信息的保留和丢弃。LSTM有三个门,分别是:遗忘门(forget gate),输入门(input gate)和输出门(output gate)。

图20-3是常见的LSTM结构,我们以任意时刻t的一个LSTM单元(LSTM cell)为例,来分析其工作原理。

图20-3 LSTM内部结构意图

20.1.2.2 LSTM的前向计算⚓︎

  1. 遗忘门

    由上图可知,遗忘门的输出为ft, 采用sigmoid激活函数,将输出映射到[0,1]区间。上一时刻细胞状态ct1通过遗忘门时,与ft结果相乘,显然,乘数为0的信息被全部丢弃,为1的被全部保留。这样就决定了上一细胞状态ct1有多少能进入当前状态ct

    遗忘门ft的公式如下:

    (1)ft=σ(ht1Wf+xtUf+bf)

    其中,σ为sigmoid激活函数,ht1 为上一时刻的隐层状态,形状为(1×h)的行向量。xt为当前时刻的输入,形状为(1×i)的行向量。参数矩阵WfUf分别是(h×h)(i×h)的矩阵,bf(1×h)的行向量。

    很多教科书或网络资料将公式写成如下格式:

    (1')ft=σ(Wf[ht1,xt]+bf)

    (1'')ft=σ([WfhWfx][ht1xt]+bf)=σ(Wfhht1+Wfxxt+bf)

    后两种形式将权重矩阵放在状态向量前面,在讲解原理时,与公式(1)没有区别,但在代码实现时会出现一些问题,所以,在本章中我们采用公式(1)的表达方式。

  2. 输入门

    输入门it决定输入信息有哪些被保留,输入信息包含当前时刻输入和上一时刻隐层输出两部分,存入即时细胞状态c~t中。输入门依然采用sigmoid激活函数,将输出映射到[0,1]区间。c~t通过输入门时进行信息过滤。

    输入门it的公式如下:

    (2)it=σ(ht1Wi+xtUi+bi)

    即时细胞状态 c~t的公式如下:

    (3)c~t=tanh(ht1Wc+xtUc+bc)

    上一时刻保留的信息,加上当前输入保留的信息,构成了当前时刻的细胞状态ct

    当前细胞状态ct的公式如下:

    (4)ct=ftct1+itc~t

    其中,符号 表示矩阵乘积, 表示 Hadamard 乘积,即元素乘积。

  3. 输出门

    最后,需要确定输出信息。

    输出门ot决定 ht1xt 中哪些信息将被输出,公式如下:

    (5)ot=σ(ht1Wo+xtUo+bo)

    细胞状态ct通过tanh激活函数压缩到 (-1, 1) 区间,通过输出门,得到当前时刻的隐藏状态ht作为输出,公式如下:

    $$ h_t=o_t \circ \tanh(c_t) \tag{6} $$s

最后,时刻t的预测输出为:

(7)at=σ(htV+b)

其中,

(8)zt=htV+b

经过上面的步骤,LSTM就完成了当前时刻的前向计算工作。

20.1.2.3 LSTM的反向传播⚓︎

LSTM使用时序反向传播算法(Backpropagation Through Time, BPTT)进行计算。图20-4是带有一个输出的LSTM cell。我们使用该图来推导反向传播过程。

图20-4 带有一个输出的LSTM单元

假设当前LSTM cell处于第l层、t时刻。那么,它从两个方向接受反向传播的误差:一个是从t+1时刻l层传回的误差,记为δhtl(注意,这里的下标不是ht+1,而是ht);另一个是从t时刻l+1层的输入传回误差,记为 δxtl+1

我们先复习几个在推导过程中会使用到的激活函数,以及其导数公式。令sigmoid = σ,则:

(9)σ(z)=y=11+ez
(10)σ(z)=y(1y)
(11)tanh(z)=y=ezezez+ez
(12)tanh(z)=1y2

假设某一线性函数 zi 经过Softmax函数之后的预测输出为 y^i,该输出的标签值为 yi,则:

(13)softmax(zi)=y^i=ezij=1mezj
(14)losszi=y^iyi

从图中可知,从上层传回的误差为输出层zthtl传回的误差,假设输出层的激活函数为softmax函数,输出层标签值为y,则:

(15)δxtl+1=lossztzthtl=(ay)V

t+1时刻传回的误差为δhtl,若t为时序的最后一个时间点,则δhtl=0

该cell的隐层htl的最终误差为两项误差之和,即:

(16)δtl=lossht=δhtl+δxtl+1=(ay)V

接下来的推导过程仅与本层相关,为了方便推导,我们忽略层次信息,令δtl=δt

可以求得各个门结构加权输入的误差,如下:

(17)δzot=losszot=losshthtototzot=δtdiag[tanh(ct)]diag[ot(1ot)]=δttanh(ct)ot(1ot)
(18)δct=lossct=losshthttanh(ct)tanh(ct)ct=δtdiag[ot]diag[1tanh2(ct)]=δtot(1tanh2(ct))
(19)δzc~t=losszc~t=lossctctc~tc~tzc~t=δctdiag[it]diag[1(c~t)2]=δctit(1(c~t)2)
(20)δzit=losszit=lossctctititzit=δctdiag[c~t]diag[it(1it)]=δctc~tit(1it)
(21)δzft=losszft=lossctctftftzft=δctdiag[ct1]diag[ft(1ft)]=δctct1ft(1ft)

于是,在t时刻,输出层参数的各项误差为:

(22)dWo,t=lossWo,t=losszotzotWo=ht1δzot
(23)dUo,t=lossUo,t=losszotzotUo=xtδzot
(24)dbo,t=lossbo,t=losszotzotbo=δzot

最终误差为各时刻误差之和,则:

(25)dWo=t=1τdWo,t=t=1τht1δzot
(26)dUo=t=1τdUo,t=t=1τxtδzot
(27)dbo=t=1τdbo,t=t=1τδzot

同理可得:

(28)dWc=t=1τdWc,t=t=1τht1δzc~t
(29)dUc=t=1τdUc,t=t=1τxtδzc~t
(30)dbc=t=1τdbc,t=t=1τδzc~t
(31)dWi=t=1τdWi,t=t=1τht1δzit
(32)dUi=t=1τdUi,t=t=1τxtδzit
(33)dbi=t=1τdbi,t=t=1τδzit
(34)dWf=t=1τdWf,t=t=1τht1δzft
(35)dUf=t=1τdUf,t=t=1τxtδzft
(36)dbf=t=1τdbf,t=t=1τδzft

当前LSTM cell分别向前一时刻(t1)和下一层(l1)传递误差,公式如下:

沿时间向前传递:

(37)δht1=lossht1=losszftzftht1+losszitzitht1+losszc~tzc~tht1+losszotzotht1=δzftWf+δzitWi+δzc~tWc+δzotWo

沿层次向下传递:

(38)δxt=lossxt=losszftzftxt+losszitzitxt+losszc~tzc~txt+losszotzotxt=δzftUf+δzitUi+δzc~tUc+δzotUo

以上,LSTM反向传播公式推导完毕。