文章目录
- RNN/LSTM/GRU
- 一、RNN
- 1、为何引入RNN?
- 2、RNN的基本结构
- 3、各种形式的RNN及其应用
- 4、RNN的缺陷
- 5、如何应对RNN的缺陷?
- 6、BPTT和BP的区别
- 二、LSTM
- 1、LSTM 简介
- 2、LSTM如何缓解梯度消失与梯度爆炸?
- 三、GRU
- 四、参考文献
RNN/LSTM/GRU
一、RNN
1、为何引入RNN?
循环神经网络(Recurrent Neural Network,RNN) 是用来建模序列化数据的一种主流深度学习模型。我们知道,传统的前馈神经网络一般的输入都是一个定长的向量,无法处理变长的序列信息,即使通过一些方法把序列处理成定长的向量,模型也很难捕捉序列中的长距离依赖关系。RNN则通过将神经元串行起来处理序列化的数据。由于每个神经元能用它的内部变量保存之前输入的序列信息,因此整个序列被浓缩成抽象的表示,并可以据此进行分类或生成新的序列1。
2、RNN的基本结构
RNN的朴素形式可分别由如下两幅图表示2:
其中 x 1 , x 2 , ⋯ , x T x_1,x_2,\cdots,x_T x1,x2,⋯,xT 是输入,每一个位置是一个实数向量; U U U、 V V V、 W W W 是权重矩阵,通常在模型初始化时随机生成,通过梯度下降进行优化; h t h_t ht 是位于隐藏层上的活性值,很多文献上也称为状态(State)或隐状态(Hidden State); p t p_t pt 表示第 t t t 个位置上的输出。
h
t
h_t
ht、
p
t
p_t
pt 可由下列公式得出(
b
b
b 是偏置项):
h
t
=
tanh
(
U
⋅
h
t
−
1
+
W
⋅
x
t
+
b
)
h_t=\tanh\left(U\cdot h_{t-1}+W\cdot x_t+b\right)
ht=tanh(U⋅ht−1+W⋅xt+b)
p
t
=
s
o
f
t
m
a
x
(
V
⋅
h
t
+
c
)
p_t=\mathrm{softmax}(V\cdot h_t+c)
pt=softmax(V⋅ht+c)
3、各种形式的RNN及其应用
(图片来自于cs231n)
模式 | 描述 | 应用领域 |
---|---|---|
One to One | 单个输入对应单个输出 | 图像分类、回归任务 |
One to Many | 单个输入生成序列输出 | 图像字幕生成、音乐生成 |
Many to One | 序列输入生成单个输出 | 情感分析、时间序列分类 |
Many to Many | 序列输入对应序列输出 | 机器翻译、语音识别 |
Many to Many(同步) | 同步序列输入输出 | 视频帧分类、实时语音处理 |
4、RNN的缺陷
RNN通过在所有时间步共享相同的权重,使得可以在不同时间步之间传递和积累信息,从而更好地捕捉序列数据中的长期依赖关系,但是缺点也很明显:在RNN的学习过程中,由于共享权重 W W W,导致随着时间步的增加,权重矩阵 W W W 不断连乘,最终产生梯度消失(即 ∂ L t ∂ h k \frac{\partial \mathcal{L}_{t}}{\partial \boldsymbol{h}_{k}} ∂hk∂Lt 消失, 1 ≤ k ≤ t 1 \le k\le t 1≤k≤t )和梯度爆炸,具体解释如下:
首先由RNN前向传播公式:
h
t
=
f
(
W
⋅
h
t
−
1
+
U
⋅
x
t
+
b
)
h_t=f(W\cdot h_{t-1}+U\cdot x_t+b)
ht=f(W⋅ht−1+U⋅xt+b)
其中
f
f
f 为激活函数。
在反向传播时(BPTT),损失函数
L
\mathcal{L}
L 对某一时间步长的梯度涉及到时间上所有的前置状态,因此梯度会被多个矩阵连乘表示为:
∂
L
∂
h
t
=
∂
L
∂
h
T
⋅
∏
k
=
t
T
−
1
A
k
\frac{\partial\mathcal{L}}{\partial h_t}=\frac{\partial\mathcal{L}}{\partial h_T}\cdot\prod_{k=t}^{T-1}A_k
∂ht∂L=∂hT∂L⋅k=t∏T−1Ak
其中
A
k
=
diag
(
f
′
(
h
k
)
)
⋅
W
A_k=\operatorname{diag}(f^{\prime}(h_k))\cdot W
Ak=diag(f′(hk))⋅W 。
显然若 W > 1 W>1 W>1,随着时间的增加,多个 W W W 连乘后结果会不断增大,最终导致梯度爆炸;
同理 W < 1 W<1 W<1,多个 W W W 连乘后结果会不断减小至趋于0,最终导致梯度消失。
而在CNN中,每一层的权重矩阵
W
W
W 是不同的,并且在初始化时它们是独立同分布的,因此最后可以相互抵消,不容易发生梯度爆炸或消失。
5、如何应对RNN的缺陷?
① 对于梯度爆炸,一般通过权重衰减(Weight Decay) 或梯度截断(Gradient Clipping) 来避免3。权重衰减,通过引入衰减系数来约束并避免权重矩阵元素过大,从而减少梯度连乘时的爆炸风险;梯度截断,直接将梯度大小进行限制以防止梯度爆炸,比如按值截断:在第
t
t
t 次迭代时,梯度为
g
t
g_t
gt ,给定一个区间
[
a
,
b
]
[a,b]
[a,b] ,如果一个参数的梯度小于
a
a
a 时,就将其设为
a
a
a ;如果大于
b
b
b 时,就将其设为
b
b
b,公式如下:
g
t
=
max
(
min
(
g
t
,
b
)
,
a
)
.
\mathbf{g}_t=\max(\min(\mathbf{g}_t,b),a).
gt=max(min(gt,b),a).
② 对于梯度消失,一个想法是改进激活函数,比如替换成 ReLU ,因为其右侧导数恒为 1 ,可以缓解梯度消失(不能杜绝,因为本质上是权重矩阵的问题)。缺点是不好解决梯度爆炸,从 RNN 的前向传播公式来看待这个问题,前向传播公式如下:
h
t
=
f
(
W
⋅
h
t
−
1
+
U
⋅
x
t
+
b
)
h_t=f(W\cdot h_{t-1}+U\cdot x_t+b)
ht=f(W⋅ht−1+U⋅xt+b)
使用 ReLU 激活函数后,
h
t
h_t
ht 可表达为:
h
t
=
r
e
l
u
(
W
⋅
h
t
−
1
+
U
⋅
x
t
+
b
)
h_t=\mathrm{relu}\left(W\cdot h_{t-1}+U\cdot x_t+b\right)
ht=relu(W⋅ht−1+U⋅xt+b)
显然不管
h
t
−
1
h_{t-1}
ht−1 怎么变化,前面始终要乘上一个权重矩阵
W
W
W ,因此替换激活函数后,并不能实质上解决由于权重矩阵
W
W
W 连乘而导致的梯度爆炸问题。
③ 使用合适的权重初始化方法,如 Xavier 初始化或 He 初始化,使 W W W 的特征值接近 1 。
如果从结构上来考虑,通过改变网络结构来减缓梯度消失或爆炸,长短期记忆网络(LSTM,Long Short-Term Memory) 就是基于这个想法诞生的。
6、BPTT和BP的区别
BP算法:只处理纵向层级间的梯度反向传播,适用于前馈神经网络。
BPTT算法:在训练RNN时,需要同时处理纵向层级间的反向传播(深度方向)和时间维度上的反向传播(时间方向)。
二、LSTM
1、LSTM 简介
LSTM 是循环神经网络的一个变体,可以有效地解决简单循环神经网络的梯度爆炸或消失问题。LSTM 网络结构如下:
在这里插入图片描述
LSTM 网络引入门控机制(Gating Mechanism) 来控制信息传递的路径,公式如下:
i
t
=
σ
(
U
i
⋅
h
t
−
1
+
W
i
⋅
x
t
+
b
i
)
f
t
=
σ
(
U
f
⋅
h
t
−
1
+
W
f
⋅
x
t
+
b
f
)
o
t
=
σ
(
U
o
⋅
h
t
−
1
+
W
o
⋅
x
t
+
b
o
)
c
~
t
=
tanh
(
U
c
⋅
h
t
−
1
+
W
c
⋅
x
t
+
b
c
)
c
t
=
i
t
⊙
c
~
t
+
f
t
⊙
c
t
−
1
h
t
=
o
t
⊙
tanh
(
c
t
)
\begin{array}{c}\boldsymbol{i}_{t}=\sigma\left(\boldsymbol{U}_{i} \cdot \boldsymbol{h}_{t-1}+\boldsymbol{W}_{i} \cdot \boldsymbol{x}_{t}+\boldsymbol{b}_{i}\right) \\\boldsymbol{f}_{t}=\sigma\left(\boldsymbol{U}_{f} \cdot \boldsymbol{h}_{t-1}+\boldsymbol{W}_{f} \cdot \boldsymbol{x}_{t}+\boldsymbol{b}_{f}\right) \\\boldsymbol{o}_{t}=\sigma\left(\boldsymbol{U}_{o} \cdot \boldsymbol{h}_{t-1}+\boldsymbol{W}_{o} \cdot \boldsymbol{x}_{t}+\boldsymbol{b}_{o}\right) \\\tilde{\boldsymbol{c}}_{t}=\tanh \left(\boldsymbol{U}_{c} \cdot \boldsymbol{h}_{t-1}+\boldsymbol{W}_{c} \cdot \boldsymbol{x}_{t}+\boldsymbol{b}_{c}\right) \\\boldsymbol{c}_{t}=\boldsymbol{i}_{t} \odot \tilde{\boldsymbol{c}}_{t}+\boldsymbol{f}_{t} \odot \boldsymbol{c}_{t-1} \\\boldsymbol{h}_{t}=\boldsymbol{o}_{\boldsymbol{t}} \odot \tanh \left(\boldsymbol{c}_{t}\right)\end{array}
it=σ(Ui⋅ht−1+Wi⋅xt+bi)ft=σ(Uf⋅ht−1+Wf⋅xt+bf)ot=σ(Uo⋅ht−1+Wo⋅xt+bo)c~t=tanh(Uc⋅ht−1+Wc⋅xt+bc)ct=it⊙c~t+ft⊙ct−1ht=ot⊙tanh(ct)
进一步可以简写成:
[
c
~
t
o
t
i
t
f
t
]
=
[
tanh
σ
σ
σ
]
(
W
[
x
t
h
t
−
1
]
+
b
)
,
c
t
=
f
t
⊙
c
t
−
1
+
i
t
⊙
c
~
t
,
h
t
=
o
t
⊙
tanh
(
c
t
)
,
\begin{aligned}\begin{bmatrix}\tilde{\boldsymbol{c}}_t\\\\\boldsymbol{o}_t\\\\\boldsymbol{i}_t\\\\\boldsymbol{f}_t\end{bmatrix}&=\begin{bmatrix}\tanh\\\\\sigma\\\\\sigma\\\\\sigma\end{bmatrix}\begin{pmatrix}\boldsymbol{W}\begin{bmatrix}\boldsymbol{x}_t\\\\\boldsymbol{h}_{t-1}\end{bmatrix}+\boldsymbol{b}\end{pmatrix},\\\\\boldsymbol{c}_t&=\boldsymbol{f}_t\odot\boldsymbol{c}_{t-1}+\boldsymbol{i}_t\odot\boldsymbol{\tilde{c}}_t,\\\boldsymbol{h}_t&=\boldsymbol{o}_t\odot\tanh\left(\boldsymbol{c}_t\right),\end{aligned}
c~totitft
ctht=
tanhσσσ
W
xtht−1
+b
,=ft⊙ct−1+it⊙c~t,=ot⊙tanh(ct),
公式中有三个“门”,分别为输入门 i t \boldsymbol{i}_t it 、遗忘门 f t \boldsymbol{f}_t ft 和输出门 o t \boldsymbol{o}_t ot 。这三个门的作用为
- 遗忘门 f t f_t ft 控制上一个时刻的内部状态 c t − 1 \boldsymbol c_t-1 ct−1 需要遗忘多少信息。
- 输入门 i t \boldsymbol{i}_t it 控制当前时刻的候选状态 c ~ t \tilde{\boldsymbol{c}}_t c~t 有多少信息需要保存。
- 输出门 o t \boldsymbol{o}_t ot 控制当前时刻的内部状态 c t \boldsymbol{c}_t ct 有多少信息需要输出给外部状态 h t . \boldsymbol{h}_t. ht.
具体的可点击查看如下视频,很清晰易懂:
【【官方双语】LSTM(长短期记忆神经网络)最简单清晰的解释来了!】 https://www.bilibili.com/video/BV1zD421N7nA/?share_source=copy_web&vd_source=199a3f4e3a9db6061e1523e94505165a
2、LSTM如何缓解梯度消失与梯度爆炸?
LSTM的细胞状态更新机制(下图黄色部分)可以有效地存储长期的信息:
其更新公式如下:
C
t
=
f
t
⊙
C
t
−
1
+
i
t
⊙
C
~
t
C_t=f_t\odot C_{t-1}+i_t\odot\tilde{C}_t
Ct=ft⊙Ct−1+it⊙C~t
由于这一过程本质是线性操作(加权求和),相当于是所有候选路径的线性组合,故不会因为一个路径上梯度的消失,而导致整体梯度不断衰减。LSTM的细胞状态经过门控机制(通过或阻断,即 1 或 0)控制这个线性组合,达到缓解梯度消失的效果;而门控机制又可以通过调节输入输出,通过灵活地舍弃一些部分,来缓解梯度爆炸问题。
简言之,由于此线性组合会通过门控机制自主的调节,而非 RNN 那样直接连乘,因此可以达到减缓梯度消失和梯度爆炸的效果,并实现对信息的过滤,从而达到对长期记忆的保存与控制。
三、GRU
门控循环单元(GRU) 是对 LSTM 进行简化得到的模型。对于 LSTM 与 GRU 而言,它们效果相当,但由于 GRU 参数更少,所以 GRU 的收敛速度更快,计算效率更高。
与LSTM相比,GRU 仅有两个门——更新门(update gate)和重置门(reset gate),不使用记忆元。重置门有助于捕获序列中的短期依赖关系,更新门有助于捕获序列中的长期依赖关系,详细结构如下图: