0


28. 深度学习进阶 - LSTM

文章目录

在这里插入图片描述

Hi, 你好。我是茶桁。

我们上一节课,用了一个示例来展示了一下我们为什么要用RNN神经网络,它和全连接的神经网络具体有什么区别。

这节课,我们就着上一节课的内容继续往后讲,没看过上节课的,建议回头去好好看看,特别是对代码的进程顺序好好的弄清楚。

全连接的模型得很仔细的去改变它的结构,然后再给它加很多东西,效果才能变好:

  1. self.linear_with_tanh = nn.Sequential(
  2. nn.Linear(10, self.hidden_size),
  3. nn.Tanh(),
  4. nn.Linear(self.hidden_size, self.hidden_size),
  5. nn.Tanh(),
  6. nn.Linear(self.hidden_size, output_size))

但是对于RNN模型来说,我们只用了两个函数:

  1. self.rnn = nn.RNN(x_size, hidden_size, n_layers, batch_first=True)
  2. self.out = nn.Linear(hidden_size, output_size)

这是一个很本质的问题, 也比较重要。为什么RNN的模型这么简单,它的效果比更复杂的全连接要好呢?

这个和我们平时生活中做各种事情其实都很类似,他背后的原因是他的信息保留的更多。RNN模型厉害的本质是在运行的过程中把更多的信息记录下来,而全连接没有记录。

对于RNN模型,还有两个点大家需要注意。

第一个,有一种叫做stacked的RNN的模型。我们RNN模型每一次输出都有一个output和hidden,把outputs和hidden作为它的输入再传给另外一个RNN模型,模型就变得更复杂,理论上可以解决些更复杂的场景。我们把这种就叫做stacked RNN。

Alt text

还有一种形式,Bidirectional RNN,双向RNN。有一个很著名的文本模型Bert, 那个B就是双向的意思。

我们回过头来看上节课我们讲过的两种网络:

  1. h
  2. t
  3. =
  4. σ
  5. h
  6. (
  7. W
  8. h
  9. x
  10. t
  11. +
  12. U
  13. h
  14. h
  15. t
  16. 1
  17. +
  18. b
  19. h
  20. )
  21. y
  22. t
  23. =
  24. σ
  25. y
  26. (
  27. W
  28. y
  29. h
  30. t
  31. +
  32. b
  33. y
  34. )
  35. \begin{align*} h_t & = \sigma_h(W_hx_t + U_hh_{t-1} + b_h) \\ y_t & = \sigma_y(W_yh_t + b_y) \end{align*}
  36. htyt​​=σh​(Whxt​+Uhht1​+bh​)=σy​(Wyht​+by​)​

在这个里面,每一时刻的y_t只和y_{t-1} 有关系,如果把所有的x一次性给到模型的时候,其实我们在这里可以给它加一个东西:

  1. h
  2. t
  3. =
  4. σ
  5. h
  6. (
  7. W
  8. h
  9. x
  10. t
  11. +
  12. U
  13. h
  14. h
  15. t
  16. 1
  17. +
  18. V
  19. h
  20. h
  21. t
  22. +
  23. 1
  24. +
  25. b
  26. h
  27. )
  28. \begin{align*} h_t & = \sigma_h(W_hx_t + U_hh_{t-1} + V_h * h_{t+1} + b_h) \end{align*}
  29. ht​​=σh​(Whxt​+Uhht1​+Vh​∗ht+1​+bh​)​

还可以写成这样,那这样的话它实现的就是每一时刻的t既和前一次有关系
和后一刻有关系。这样我们每一次的值不仅和前面有关,还和后面有关。就叫做双向RNN。

Alt text

对于RNN来说,它有一个很严重的问题,就是之前说过的,它的vanishing和exploding的问题会很明显, 也就是梯度消失和爆炸问题。

在这里插入图片描述

想一下,现在如果有一个loss,那它最终的loss是不是对于{x1, x2, …, xn}都有关系,比方说现在要求

  1. l
  2. o
  3. s
  4. s
  5. w
  6. 1
  7. \frac{\partial loss}{\partial w_1}
  8. w1​∂loss​, 假如说现在h100 那这种调用关系就是
  9. l
  10. o
  11. s
  12. s
  13. w
  14. 1
  15. =
  16. h
  17. 100
  18. h
  19. 99
  20. h
  21. 99
  22. h
  23. 98
  24. .
  25. .
  26. .
  27. h
  28. 0
  29. w
  30. 1
  31. \begin{align*} \frac{\partial loss}{\partial w_1} = \frac{\partial h_{100}}{\partial h_{99}} \cdot \frac{\partial h_{99}}{\partial h_{98}} \cdot ... \cdot \frac{\partial h_{0}}{\partial w_{1}} \end{align*}
  32. w1​∂loss​=∂h99​∂h100​​⋅∂h98​∂h99​​⋅...⋅∂w1​∂h0​​​

loss对于w1求偏导的时候,其实loss最先接受的是离他最近的, 假如说是h100。h100调用了h99,h99调用h98,就这个调用过程,这一串东西会变得很长。

我们之前课程说过一些情况,怎么去解决这个问题呢?对于RNN模型来说梯度爆炸很好解决,就直接设定一个阈值就可以了,起码也是能学习的。

在这里插入图片描述

要讲的是想一种方法怎么样来解决梯度消失的问题。这个梯度消失的解决方法,就叫LSTM。要解决梯度消失,就是要用LSTM: Long Short-Term Memory,长短记忆模型,既能保持长信息,又能保持短信息。

在之前那个很长的过程中,怎么样能够让它不消散呢?LSTM的核心思想是通过门控机制来控制信息的流动和及已的更新,包含了Input Gate, Forget Gate,Cell State以及Output Gate。这些会一起协作来处理序列数据。

其中Input Gate控制着新信息的输入,以及信息对细胞状态的影响。 Forget Gate控制着细胞状态中哪些信息应该被易王,Cell State用于传递信息,是LSTM的核心,Output Gate控制着细胞状态如何影响输出。

这里每一个门控单元都由一个Sigmoid激活函数来控制信息的流动,以及一个Tanh激活函数来确定信息的值。

  1. I
  2. n
  3. p
  4. u
  5. t
  6. G
  7. a
  8. t
  9. e
  10. i
  11. t
  12. =
  13. σ
  14. (
  15. W
  16. i
  17. [
  18. h
  19. t
  20. 1
  21. ,
  22. x
  23. t
  24. ]
  25. +
  26. b
  27. i
  28. )
  29. C
  30. t
  31. =
  32. tanh
  33. (
  34. W
  35. c
  36. [
  37. h
  38. t
  39. 1
  40. ,
  41. x
  42. t
  43. ]
  44. +
  45. b
  46. c
  47. )
  48. C
  49. t
  50. =
  51. f
  52. t
  53. C
  54. t
  55. 1
  56. +
  57. i
  58. t
  59. C
  60. t
  61. F
  62. o
  63. r
  64. g
  65. e
  66. t
  67. G
  68. a
  69. t
  70. e
  71. f
  72. t
  73. =
  74. σ
  75. (
  76. W
  77. f
  78. [
  79. h
  80. t
  81. 1
  82. ,
  83. x
  84. t
  85. ]
  86. +
  87. b
  88. f
  89. )
  90. C
  91. t
  92. =
  93. f
  94. t
  95. C
  96. t
  97. 1
  98. +
  99. i
  100. t
  101. C
  102. t
  103. O
  104. u
  105. t
  106. p
  107. u
  108. t
  109. G
  110. a
  111. t
  112. e
  113. o
  114. t
  115. =
  116. σ
  117. (
  118. W
  119. o
  120. [
  121. h
  122. t
  123. 1
  124. ,
  125. x
  126. t
  127. ]
  128. +
  129. b
  130. o
  131. )
  132. h
  133. t
  134. =
  135. o
  136. t
  137. tanh
  138. (
  139. C
  140. t
  141. )
  142. \begin{align*} Input Gate \\ i_t & = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) \\ C't & = \tanh(W_c \cdot [h{t-1}, x_t] + b_c) \\ C_t & = f_t \cdot C_{t-1} + i_t \cdot C'_t \\ Forget Gate \\ f_t & = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) \\ C_t & = f_t \cdot C_{t-1} + i_t \cdot C'_t \\ Output Gate \\ o_t & = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) \\ h_t & = o_t \cdot \tanh(C_t) \end{align*}
  143. InputGateit​C′tCt​ForgetGateft​Ct​OutputGateot​ht​​=σ(Wi​⋅[ht−1​,xt​]+bi​)=tanh(Wc​⋅[ht−1,xt​]+bc​)=ft​⋅Ct−1​+it​⋅Ct′​=σ(Wf​⋅[ht−1​,xt​]+bf​)=ft​⋅Ct−1​+it​⋅Ct′​=σ(Wo​⋅[ht−1​,xt​]+bo​)=ot​⋅tanh(Ct​)​

其中,

  1. h
  2. t
  3. 1
  4. h_{t-1}
  5. ht1 是前一个时间步的隐藏状态,
  6. x
  7. t
  8. x_t
  9. xt 是当前时间步的输入,
  10. W
  11. i
  12. ,
  13. W
  14. f
  15. ,
  16. W
  17. o
  18. ,
  19. W
  20. c
  21. W_i, W_f, W_o, W_c
  22. Wi​,Wf​,Wo​,Wc 是权重矩阵,
  23. b
  24. i
  25. ,
  26. b
  27. f
  28. ,
  29. b
  30. o
  31. ,
  32. b
  33. c
  34. b_i, b_f, b_o, b_c
  35. bi​,bf​,bo​,bc 是偏置。

Alt text

LSTM输入的是一个序列数据,可以是文本、时间序列,音频信号等等。那每个时间步的输入是序列中的饿一个元素,比如一个单词、一个时间点的观测值等等。

假设我们有一个序列 x = [x1, x2, …, xt], 其中t就代表的是时间步。

xt进来的时候, 之前我们是只接收一个hidden state, 现在我们多接收了一个

  1. C
  2. t
  3. 1
  4. C_{t-1}
  5. Ct1​,这个就是我们的Cell,这一步的
  6. C
  7. t
  8. 1
  9. C_{t-1}
  10. Ct1​其实就是上一步的
  11. C
  12. t
  13. C_t
  14. Ct​。

在训练开始时,需要初始化LSTM单元的隐藏状态h0和细胞状态c0。通常我们初始化它们为全零向量。

最开始的时候,我们要进入Input Gate, 对于每个时间步t, 计算输入门的激活值

  1. i
  2. t
  3. i_t
  4. it​,控制新信息的输入。使用Sigmoid函数来计算输入门的值:
  5. i
  6. t
  7. =
  8. σ
  9. (
  10. W
  11. i
  12. [
  13. h
  14. t
  15. 1
  16. ,
  17. x
  18. t
  19. ]
  20. +
  21. b
  22. i
  23. )
  24. i_t = \sigma (W_i \cdot [h_{t-1}, x_t] + b_i)
  25. it​=σ(Wi​⋅[ht1​,xt​]+bi​)

然后,计算新的侯选值

  1. C
  2. t
  3. C'_t
  4. Ct′​, 这是在当前时间步考虑的新信息。使用tanh激活函数来计算侯选值:
  5. C
  6. t
  7. =
  8. t
  9. a
  10. n
  11. h
  12. (
  13. W
  14. c
  15. [
  16. h
  17. t
  18. 1
  19. ,
  20. x
  21. t
  22. ]
  23. +
  24. b
  25. c
  26. )
  27. C'_t = tanh(W_c \cdot [h_{t-1}, x_t] + b_c)
  28. Ct′​=tanh(Wc​⋅[ht1​,xt​]+bc​)

接下来我们就要更新细胞状态了,细胞状态

  1. C
  2. t
  3. C_t
  4. Ct​更新是通过遗忘门
  5. f
  6. t
  7. f_t
  8. ft​和输入门
  9. i
  10. t
  11. i_t
  12. it​控制的。遗忘门控制着哪些信息应该被遗忘,输入门控制新信息对细胞状态的影响:
  13. C
  14. t
  15. =
  16. f
  17. t
  18. C
  19. t
  20. 1
  21. +
  22. i
  23. t
  24. C
  25. t
  26. C_t = f_t \cdot C_{t-1} + i_t \cdot C'_t
  27. Ct​=ft​⋅Ct−1​+it​⋅Ct′​

那遗忘门决定哪些信息应该被遗忘,使用的就是Sigmoid函数计算遗忘门的激活值。

  1. f
  2. t
  3. =
  4. σ
  5. (
  6. W
  7. f
  8. [
  9. h
  10. t
  11. 1
  12. ,
  13. x
  14. t
  15. ]
  16. +
  17. b
  18. f
  19. )
  20. f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f)
  21. ft​=σ(Wf​⋅[ht1​,xt​]+bf​)

接着,计算输出门

  1. O
  2. t
  3. O_t
  4. Ot​, 控制着细胞状态如何影响输出和隐藏状态。一样,我们还是使用Sigmoid函数计算。
  5. o
  6. t
  7. =
  8. σ
  9. (
  10. W
  11. o
  12. [
  13. h
  14. t
  15. 1
  16. ,
  17. x
  18. t
  19. ]
  20. +
  21. b
  22. o
  23. )
  24. o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o)
  25. ot​=σ(Wo​⋅[ht1​,xt​]+bo​)

使用输出门的值

  1. o
  2. t
  3. o_t
  4. ot​来计算最终的隐藏状态
  5. h
  6. t
  7. h_t
  8. ht​和输出。 隐藏状态和输出都是根据细胞状态和输出门的值来计算的:
  9. h
  10. t
  11. =
  12. o
  13. t
  14. t
  15. a
  16. n
  17. h
  18. (
  19. C
  20. t
  21. )
  22. h_t = o_t \cdot tanh(C_t)
  23. ht​=ot​⋅tanh(Ct​)

接下来就容易了,我们迭代重复上述过程,处理序列中的每一个时间步,直到处理完整个序列。

LSTM的输出可以是隐藏状态

  1. h
  2. t
  3. h_t
  4. ht​, 也可以是细胞状态
  5. C
  6. t
  7. C_t
  8. Ct​, 具体是取决于应用的需求。

后来大家就发现了一种改进的LSTM,其中门控机制允许细胞状态窥视现前的细胞状态的信息,而不仅仅是根据当前时间步的输入和隐藏状态来决定。 这个机制在LSTM单源种引入了额外的权重和连接,以允许细胞状态在门控过程中访问现前的细胞状态,我们称之为窥视孔连接: Peephole connections。

  1. f
  2. t
  3. =
  4. σ
  5. (
  6. W
  7. f
  8. [
  9. C
  10. t
  11. 1
  12. ,
  13. h
  14. t
  15. 1
  16. ,
  17. x
  18. t
  19. ]
  20. +
  21. b
  22. f
  23. )
  24. i
  25. t
  26. =
  27. σ
  28. (
  29. W
  30. i
  31. [
  32. C
  33. t
  34. 1
  35. ,
  36. h
  37. t
  38. 1
  39. ,
  40. x
  41. t
  42. ]
  43. +
  44. b
  45. i
  46. )
  47. o
  48. t
  49. =
  50. σ
  51. (
  52. W
  53. o
  54. [
  55. C
  56. t
  57. 1
  58. ,
  59. h
  60. t
  61. 1
  62. ,
  63. x
  64. t
  65. ]
  66. +
  67. b
  68. o
  69. )
  70. \begin{align*} f_t = \sigma(W_f \cdot [C_{t-1}, h_{t-1}, x_t] + b_f) \\ i_t = \sigma(W_i \cdot [C_{t-1}, h_{t-1}, x_t] + b_i) \\ o_t = \sigma(W_o \cdot [C_{t-1}, h_{t-1}, x_t] + b_o) \\ \end{align*}
  71. ft​=σ(Wf​⋅[Ct1​,ht1​,xt​]+bf​)it​=σ(Wi​⋅[Ct1​,ht1​,xt​]+bi​)ot​=σ(Wo​⋅[Ct1​,ht1​,xt​]+bo​)​

之前,我们是xt和x_{t-1}决定的f,那现在又把c_{t-1}加上了。就是多加了一些信息。

除此之外它有一个方法GRU,这个是2014年提出来的,Geted Recurrent Unit,它是LSTM的一个简化版本。

它最核心的内容:

  1. h
  2. t
  3. =
  4. (
  5. 1
  6. z
  7. t
  8. )
  9. h
  10. t
  11. 1
  12. +
  13. z
  14. t
  15. h
  16. t
  17. \begin{align*} h_t = (1-z_t) \cdot h_{t-1} + z_t \cdot h'_t \end{align*}
  18. ht​=(1−zt​)⋅ht−1​+zt​⋅ht′​​

咱们刚刚是

  1. C
  2. t
  3. =
  4. f
  5. t
  6. C
  7. t
  8. 1
  9. +
  10. i
  11. t
  12. C
  13. t
  14. C_t = f_t \cdot C_{t-1} + i_t \cdot C'_t
  15. Ct​=ft​⋅Ct−1​+it​⋅Ct′​,也就是遗忘加上输入,那我们对过去保留越多的时候,

输入就会越小,那对过去保留越小的时候,输入就会越大。

所以既然f也是1-0,i也是0-1,f大的时候i就小,f小的时候i就大,那么能不能写成f=(1-i)?

于是,GRU就这样实现了, 它其实最核心的就做了这样一件事, f=(1-i)。

  1. z
  2. t
  3. =
  4. σ
  5. (
  6. W
  7. z
  8. [
  9. h
  10. t
  11. 1
  12. ,
  13. x
  14. t
  15. ]
  16. )
  17. r
  18. t
  19. =
  20. σ
  21. (
  22. W
  23. r
  24. [
  25. h
  26. t
  27. 1
  28. ,
  29. x
  30. t
  31. ]
  32. )
  33. h
  34. t
  35. =
  36. tanh
  37. (
  38. W
  39. [
  40. r
  41. t
  42. h
  43. t
  44. 1
  45. ,
  46. x
  47. t
  48. ]
  49. )
  50. h
  51. t
  52. =
  53. (
  54. 1
  55. z
  56. t
  57. )
  58. h
  59. t
  60. 1
  61. +
  62. z
  63. t
  64. h
  65. t
  66. \begin{align*} z_t & = \sigma(W_z \cdot [h_{t-1}, x_t]) \\ r_t & = \sigma(W_r \cdot [h_{t-1}, x_t]) \\ h'_t & = \tanh(W \cdot [r_t \cdot h_{t-1}, x_t]) \\ h_t & = (1-z_t) \cdot h_{t-1} + z_t \cdot h'_t \end{align*}
  67. ztrtht′​ht​​=σ(Wz​⋅[ht1​,xt​])=σ(Wr​⋅[ht1​,xt​])=tanh(W⋅[rt​⋅ht1​,xt​])=(1zt​)⋅ht1​+zt​⋅ht′​​

这个z其实和i是一样的东西,只是原作者为了发表论文方便而改了个名称。

https://arxiv.org/pdf/1406.1078v3.pdf

  1. r
  2. t
  3. r_t
  4. rt​是来控制上一时刻的
  5. h
  6. t
  7. h_t
  8. ht​在我们此时此刻的重要性、影响程度。那我们可以将
  9. r
  10. t
  11. h
  12. t
  13. 1
  14. r_t \cdot h_{t-1}
  15. rt​⋅ht1​看成是关于及已的,
  16. 1
  17. z
  18. t
  19. 1-z_t
  20. 1zt​也是关于记忆的。

GRU这样做之后有什么好处呢?

原来我们有三个门: f, i, o, 那现在变成了两个,z和r。为什么就更好了呢?我们在PyTorch里面往往用的是GRU。

大家想一下,是不是少了一个门其实就少了一个矩阵?我们看公式的时候,

  1. W
  2. f
  3. W_f
  4. Wf​是一个数学符号,但是在背后其实是一个矩阵,是一个矩阵的话少了一个矩阵意味着参数就少多了,运算就更快了等等。

但其实这些都不是最关键的,最关键的是减少过拟合了。我们之前的课程中一再强调,过拟合之所以产生,最主要的原因是数据不够或者说是模型太复杂。

但是在现有的数据情况下,为了让数据发挥出最大效力,你把需要训练的模型变简单,参数变少,就没有那么复杂了。

关于RNN模型,我们后面还会介绍一些具体的示例。


本文转载自: https://blog.csdn.net/ivandoo/article/details/134906944
版权归原作者 茶桁 所有, 如有侵权,请联系我们删除。

“28. 深度学习进阶 - LSTM”的评论:

还没有评论