RNN及其变体
00 min
2024-4-23
RNN结构可以很好利用序列之间的关系, 因此针对自然界具有连续性的输入序列, 如人类的语言, 语音等进行很好的处理, 广泛应用于NLP领域的各项任务, 如文本分类, 情感分析, 意图识别, 机器翻译等

1. 认识RNN模型

RNN(Recurrent Neural Network), 中文称作循环神经网络, 它一般以序列数据为输入, 通过网络内部的结构设计有效捕捉序列之间的关系特征, 一般也是以序列形式进行输出.下图为RNN单层网络结构:
notion image
上图中,RNN的输入有两个,一个是当前时间步输入,另一个是上一时间步的隐藏层(hidden)输出,输出有两个,为当前时间步的输出和隐藏层可以看出,RNN的循环机制使模型上一时间步产生的结果, 能够作为当下时间步输入的一部分(当下时间步的输入除了正常的输入外还包括上一步的隐层输出)对当下时间步的输出产生影响。

1.1 RNN模型分类

在这里我们将从两个角度对RNN模型进行分类. 第一个角度是输入和输出的结构, 第二个角度是RNN的内部构造.

1.1.1 输入与输出分类

  • N vs N - RNN
    • RNN最基础形式,特点为输入输出等长,因这一限制,使用范围较小,可用于生成等长度诗句
      • notion image
  • N vs 1 - RNN
    • 输入一个序列,输出为一个值(非序列),多用于文本分类任务。
    • 需要在最后一个隐藏层输出h上进行线性变换,(大部分情况下)然后为了更好的明确结果,还要使用sigmoid(二分类)或者softmax(多分类)进行处理
      • notion image
  • 1 vs N - RNN
    • 输入是值,输出为序列,即输入作用于每次的输出之上。这种结构可用于将文本生成文字任务。
      • notion image
  • N vs M - RNN
    • 不限输入输出长度的RNN结构。它由编码器和解码器两部分组成, 两者的内部结构都是某类RNN, 它也被称为seq2seq架构. 输入数据首先通过编码器, 最终输出一个隐含变量c, 之后最常用的做法是使用这个隐含变量c作用在解码器进行解码的每一步上, 以保证输入信息被有效利用.
    • seq2seq架构最早被提出应用于机器翻译, 因为其输入输出不受限制,如今也是应用最广的RNN模型结构. 在机器翻译, 阅读理解, 文本摘要等众多领域都进行了非常多的应用实践.
      • notion image

1.1.2 RNN内部结构分类

  • 传统RNN
  • LSTM
  • Bi-LSTM
  • GRU
  • Bi-GRU
在后面的内容,我们将对传统RNN、LSTM和GRU进行详细讲解

2. 传统RNN

2.1 RNN结构分析

notion image
我们把目光集中在中间的方块部分, 它的输入有两部分, 分别是h(t-1)以及x(t), 代表上一时间步的隐层输出, 以及此时间步的输入, 它们进入RNN结构体后, 会"融合"到一起, 这种融合我们根据结构解释可知, 是将二者进行拼接, 形成新的张量[x(t), h(t-1)], 之后这个新的张量将通过一个全连接层(线性层), 该层使用tanh作为激活函数, 最终得到该时间步的输出h(t), 它将作为下一个时间步的输入和x(t+1)一起进入结构体. 以此类推.
  • 根据结构分析得出内部计算公式:
  • 激活函数tanh的作用:用于帮助调节流经网络的值, tanh函数将值压缩在-1和1之间.

2.2 形状的变化

  • nn.RNN(input_size, hidden_size, num_layer)
    • input_size:输入样本词向量维度,即一个词用多少个0/1进行表示
    • hidden_size:隐藏层(输出层)的维度, 隐藏层神经元的个数
    • num_layer:网络层数,默认为1,可省,一般也设为1
    • nn.RNN()中,参数nonlinearity: 激活函数的选择, 默认是tanh.
  • input:input为传入的训练样本形状,为(sequence_length, batch_size, input_size)
    • sequence_length: 单个样本的输入序列的长度
    • batch_size: 批次的样本数
  • h0、hn:(num_layer * num_directions, batch_size, hidden_size)
    • RNN中,需要初始化h0, 一般为给定形状的全零张量
    • num_layer *num_directions:网络方向
  • output:(sequence_length, batch_size, hidden_size)
注:对于三维张量h,形状为(1, 3, 4), 那么h[-1]和h[0]都表示的是形状为(3, 4)的二维张量

2.3 代码演示

在torch.nn工具包之中, 通过torch.nn.RNN可调用

2.4 传统RNN优缺点

1、优缺点:

  • 优点:
    • 由于内部结构简单, 对计算资源要求低, 相比之后我们要学习的RNN变体:LSTM和GRU模型参数总量少了很多, 在短序列任务上性能和效果都表现优异.
  • 缺点
    • 传统RNN在解决长序列之间的关联时, 通过实践,证明经典RNN表现很差, 原因是在进行反向传播的时候, 过长的序列导致梯度的计算异常, 发生梯度消失或爆炸.

2、梯度消失或爆炸介绍

根据反向传播算法和链式法则, 梯度的计算可以简化为以下公式:
notion image
  • 其中sigmoid的导数值域是固定的, 在[0, 0.25]之间, 而一旦公式中的w也小于1, 那么通过这样的公式连乘后, 最终的梯度就会变得非常非常小, 这种现象称作梯度消失. 反之, 如果我们人为的增大w的值, 使其大于1, 那么连乘够就可能造成梯度过大, 称作梯度爆炸.
  • 梯度消失或爆炸的危害:
    • 如果在训练过程中发生了梯度消失,权重无法被更新,最终导致训练失败; 梯度爆炸所带来的梯度过大,大幅度更新网络参数,在极端情况下,结果会溢出(NaN值).

3. LSTM和Bi-LSTM

LSTM(Long Short-Term Memory)也称长短时记忆结构, 它是传统RNN的变体, 与经典RNN相比能够有效捕捉长序列之间的语义关联, 缓解梯度消失或爆炸现象. 同时LSTM的结构更复杂, 它的核心结构可以分为四个部分去解析:遗忘门、输入门、细胞状态、输出门
notion image

3.1 内部结构解析

3.1.1 遗忘门

notion image
与传统RNN的内部结构计算非常相似, 首先将当前时间步输入x(t)与上一个时间步隐含状态h(t-1)拼接, 得到[x(t), h(t-1)], 然后通过一个全连接层做变换, 最后通过sigmoid函数进行激活得到f(t), 我们可以将f(t)看作是门值, 好比一扇门开合的大小程度, 门值都将作用在通过该扇门的张量, 遗忘门门值将作用的上一层的细胞状态上, 代表遗忘过去的多少信息, 又因为遗忘门门值是由x(t), h(t-1)计算得来的, 因此整个公式意味着根据当前时间步输入和上一个时间步隐含状态h(t-1)来决定遗忘多少上一层的细胞状态所携带的过往信息.

3.1.2 输入门

notion image
我们看到输入门的计算公式有两个, 第一个就是产生输入门门值的公式, 它和遗忘门公式几乎相同, 区别只是在于它们之后要作用的目标上. 这个公式意味着输入信息有多少需要进行过滤. 输入门的第二个公式是与传统RNN的内部结构计算相同. 对于LSTM来讲, 它得到的是当前的细胞状态, 而不是像经典RNN一样得到的是隐含状态.

3.1.3 细胞状态

notion image
细胞更新的结构与计算公式非常容易理解, 这里没有全连接层, 只是将刚刚得到的遗忘门门值与上一个时间步得到的C(t-1)相乘, 再加上输入门门值与当前时间步得到的未更新C(t)相乘的结果. 最终得到更新后的C(t)作为下一个时间步输入的一部分. 整个细胞状态更新过程就是对遗忘门和输入门的应用.

3.1.4 输出门

notion image
输出门部分的公式也是两个, 第一个即是计算输出门的门值, 它和遗忘门,输入门计算方式相同. 第二个即是使用这个门值产生隐含状态h(t), 他将作用在更新后的细胞状态C(t)上, 并做tanh激活, 最终得到h(t)作为下一时间步输入的一部分. 整个输出门的过程, 就是为了产生隐含状态h(t).

3.2 Bi-LSTM介绍

Bi-LSTM即双向LSTM, 它没有改变LSTM本身任何的内部结构, 只是将LSTM应用两次且方向不同, 再将两次得到的LSTM结果进行拼接作为最终输出.
notion image
我们看到图中对"我爱中国"这句话或者叫这个输入序列, 进行了从左到右和从右到左两次LSTM处理, 将得到的结果张量进行了拼接作为最终输出. 这种结构能够捕捉语言语法中一些特定的前置或后置特征, 增强语义关联,但是模型参数和计算复杂度也随之增加了一倍, 一般需要对语料和计算资源进行评估后决定是否使用该结构.

3.3 构建LSTM模型

3.4 LSTM优缺点

  • 优势:
    • LSTM的门结构能够有效减缓长序列问题中可能出现的梯度消失或爆炸, 虽然并不能杜绝这种现象, 但在更长的序列问题上表现优于传统RNN.
  • 缺点:
    • 由于内部结构相对较复杂, 因此训练效率在同等算力下较传统RNN低很多.

4. GRU和Bi-GRU

GRU(Gated Recurrent Unit)也称门控循环单元结构, 它也是传统RNN的变体, 同LSTM一样能够有效捕捉长序列之间的语义关联, 缓解梯度消失或爆炸现象. 同时它的结构和计算要比LSTM更简单, 它的核心结构可以分为两个部分去解析:更新门、重置门