0


自然语言处理系列(一)——RNN基础

注: 本文是总结性文章,叙述较为简洁,不适合初学者

目录

一、为什么要有RNN?

普通的MLP无法处理序列信息(如文本、语音等),这是因为序列是不定长的,而MLP的输入层神经元个数是固定的。

二、RNN的结构

普通MLP的结构(以单隐层为例):

在这里插入图片描述

普通RNN(又称Vanilla RNN,接下来都将使用这一说法)的结构(在单隐层MLP的基础上进行改造):

在这里插入图片描述

  1. t
  2. t
  3. t 时刻隐藏层接收的输入来自于
  4. t
  5. 1
  6. t-1
  7. t1 时刻隐藏层的输出和
  8. t
  9. t
  10. t 时刻的样例输入。用数学公式表示,就是
  11. h
  12. (
  13. t
  14. )
  15. =
  16. tanh
  17. (
  18. W
  19. h
  20. (
  21. t
  22. 1
  23. )
  24. +
  25. U
  26. x
  27. (
  28. t
  29. )
  30. +
  31. b
  32. )
  33. ,
  34. o
  35. (
  36. t
  37. )
  38. =
  39. V
  40. h
  41. (
  42. t
  43. )
  44. +
  45. c
  46. ,
  47. y
  48. ^
  49. (
  50. t
  51. )
  52. =
  53. softmax
  54. (
  55. o
  56. (
  57. t
  58. )
  59. )
  60. h^{(t)}=\tanh(Wh^{(t-1)}+Ux^{(t)}+b),\quad o^{(t)}=Vh^{(t)}+c,\quad \hat{y}^{(t)}=\text{softmax}(o^{(t)})
  61. h(t)=tanh(Wh(t1)+Ux(t)+b),o(t)=Vh(t)+c,y^​(t)=softmax(o(t))

训练RNN的过程中,实际上就是在学习

  1. U
  2. ,
  3. V
  4. ,
  5. W
  6. ,
  7. b
  8. ,
  9. c
  10. U,V,W,b,c
  11. U,V,W,b,c 这些参数。

正向传播后,我们需要计算损失,设时间步

  1. t
  2. t
  3. t 处求得的损失为
  4. L
  5. (
  6. t
  7. )
  8. =
  9. L
  10. (
  11. t
  12. )
  13. (
  14. y
  15. ^
  16. (
  17. t
  18. )
  19. ,
  20. y
  21. (
  22. t
  23. )
  24. )
  25. L^{(t)}=L^{(t)}(\hat{y}^{(t)},y^{(t)})
  26. L(t)=L(t)(y^​(t),y(t)),则总的损失为
  27. L
  28. =
  29. t
  30. =
  31. 1
  32. T
  33. L
  34. (
  35. t
  36. )
  37. L=\sum_{t=1}^T L^{(t)}
  38. L=∑t=1TL(t)。

2.1 BPTT

BPTT(BackPropagation Through Time),通过时间反向传播是RNN训练过程中的一个术语。因为正向传播时是沿着时间流逝的方向进行的,而反向传播则是逆着时间进行的。

为方便后续推导,我们先改进一下符号表述:

  1. h
  2. (
  3. t
  4. )
  5. =
  6. tanh
  7. (
  8. W
  9. h
  10. h
  11. h
  12. (
  13. t
  14. 1
  15. )
  16. +
  17. W
  18. x
  19. h
  20. x
  21. (
  22. t
  23. )
  24. +
  25. b
  26. )
  27. ,
  28. o
  29. (
  30. t
  31. )
  32. =
  33. W
  34. h
  35. o
  36. h
  37. (
  38. t
  39. )
  40. +
  41. c
  42. ,
  43. y
  44. ^
  45. (
  46. t
  47. )
  48. =
  49. softmax
  50. (
  51. o
  52. (
  53. t
  54. )
  55. )
  56. h^{(t)}=\tanh(W_{hh}h^{(t-1)}+W_{xh}x^{(t)}+b),\quad o^{(t)}=W_{ho}h^{(t)}+c,\quad \hat{y}^{(t)}=\text{softmax}(o^{(t)})
  57. h(t)=tanh(Whhh(t1)+Wxhx(t)+b),o(t)=Whoh(t)+c,y^​(t)=softmax(o(t))

做一个水平方向的 concatenation:

  1. W
  2. =
  3. (
  4. W
  5. h
  6. h
  7. ,
  8. W
  9. x
  10. h
  11. )
  12. W=(W_{hh},W_{xh})
  13. W=(Whh​,Wxh​),为简便起见,省略偏置
  14. b
  15. b
  16. b,则有
  17. h
  18. (
  19. t
  20. )
  21. =
  22. tanh
  23. (
  24. W
  25. (
  26. h
  27. (
  28. t
  29. 1
  30. )
  31. x
  32. (
  33. t
  34. )
  35. )
  36. )
  37. h^{(t)}=\tanh\left(W \begin{pmatrix} h^{(t-1)} \\ x^{(t)} \end{pmatrix} \right)
  38. h(t)=tanh(W(h(t1)x(t)​))

,接下来我们将关注参数

  1. W
  2. W
  3. W 的学习。

注意到

  1. h
  2. (
  3. t
  4. )
  5. h
  6. (
  7. t
  8. 1
  9. )
  10. =
  11. tanh
  12. (
  13. W
  14. h
  15. h
  16. h
  17. (
  18. t
  19. 1
  20. )
  21. +
  22. W
  23. x
  24. h
  25. x
  26. (
  27. t
  28. )
  29. )
  30. W
  31. h
  32. h
  33. ,
  34. L
  35. W
  36. =
  37. t
  38. =
  39. 1
  40. T
  41. L
  42. (
  43. t
  44. )
  45. W
  46. \frac{\partial h^{(t)}}{\partial h^{(t-1)}}=\tanh'(W_{hh}h^{(t-1)}+W_{xh}x^{(t)})W_{hh},\quad \frac{\partial L}{\partial W}=\sum_{t=1}^T\frac{\partial L^{(t)}}{\partial W}
  47. ∂h(t−1)∂h(t)​=tanh′(Whh​h(t−1)+Wxh​x(t))Whh​,∂W∂L​=t=1∑T​∂W∂L(t)​

从而

  1. L
  2. (
  3. T
  4. )
  5. W
  6. =
  7. L
  8. (
  9. T
  10. )
  11. h
  12. (
  13. T
  14. )
  15. h
  16. (
  17. T
  18. )
  19. h
  20. (
  21. T
  22. 1
  23. )
  24. h
  25. (
  26. 2
  27. )
  28. h
  29. (
  30. 1
  31. )
  32. h
  33. (
  34. 1
  35. )
  36. W
  37. =
  38. L
  39. (
  40. T
  41. )
  42. h
  43. (
  44. T
  45. )
  46. t
  47. =
  48. 2
  49. T
  50. h
  51. (
  52. t
  53. )
  54. h
  55. (
  56. t
  57. 1
  58. )
  59. h
  60. (
  61. 1
  62. )
  63. W
  64. =
  65. L
  66. (
  67. T
  68. )
  69. h
  70. (
  71. T
  72. )
  73. (
  74. t
  75. =
  76. 2
  77. T
  78. tanh
  79. (
  80. W
  81. h
  82. h
  83. h
  84. (
  85. t
  86. 1
  87. )
  88. +
  89. W
  90. x
  91. h
  92. x
  93. (
  94. t
  95. )
  96. )
  97. )
  98. W
  99. h
  100. h
  101. T
  102. 1
  103. h
  104. (
  105. 1
  106. )
  107. W
  108. \begin{aligned} \frac{\partial L^{(T)}}{\partial W}&=\frac{\partial L^{(T)}}{\partial h^{(T)}}\cdot \frac{\partial h^{(T)}}{\partial h^{(T-1)}}\cdots \frac{\partial h^{(2)}}{\partial h^{(1)}}\cdot\frac{\partial h^{(1)}}{\partial W} \\ &=\frac{\partial L^{(T)}}{\partial h^{(T)}}\cdot \prod_{t=2}^T\frac{\partial h^{(t)}}{\partial h^{(t-1)}}\cdot\frac{\partial h^{(1)}}{\partial W}\\ &=\frac{\partial L^{(T)}}{\partial h^{(T)}}\cdot \left(\prod_{t=2}^T\tanh'(W_{hh}h^{(t-1)}+W_{xh}x^{(t)})\right)\cdot W_{hh}^{T-1} \cdot\frac{\partial h^{(1)}}{\partial W}\\ \end{aligned}
  109. ∂W∂L(T)​​=∂h(T)∂L(T)​⋅∂h(T−1)∂h(T)​⋯∂h(1)∂h(2)​⋅∂W∂h(1)​=∂h(T)∂L(T)​⋅t=2∏T​∂h(t−1)∂h(t)​⋅∂W∂h(1)​=∂h(T)∂L(T)​⋅(t=2∏T​tanh′(Whh​h(t−1)+Wxh​x(t)))⋅WhhT−1​⋅∂W∂h(1)​​

因为

  1. tanh
  2. (
  3. )
  4. \tanh'(\cdot)
  5. tanh′(⋅) 几乎总是小于
  6. 1
  7. 1
  8. 1 的,当
  9. T
  10. T
  11. T 足够大时将会出现梯度消失现象。

假如不采用非线性的激活函数,为简便起见,不妨设激活函数为恒等映射

  1. f
  2. (
  3. x
  4. )
  5. =
  6. x
  7. f(x)=x
  8. f(x)=x,于是有
  9. L
  10. (
  11. T
  12. )
  13. W
  14. =
  15. L
  16. (
  17. T
  18. )
  19. h
  20. (
  21. T
  22. )
  23. W
  24. h
  25. h
  26. T
  27. 1
  28. h
  29. (
  30. 1
  31. )
  32. W
  33. \frac{\partial L^{(T)}}{\partial W}=\frac{\partial L^{(T)}}{\partial h^{(T)}}\cdot W_{hh}^{T-1} \cdot\frac{\partial h^{(1)}}{\partial W}
  34. WL(T)​=∂h(T)∂L(T)​⋅WhhT1​⋅∂Wh(1)​
  • 当 W h h W_{hh} Whh​ 的最大奇异值大于 1 1 1 时,会出现梯度爆炸。
  • 当 W h h W_{hh} Whh​ 的最大奇异值小于 1 1 1 时,会出现梯度消失。

三、RNN的分类

按照输入和输出的结构可以对RNN进行如下分类:

  • 1 vs N(vec2seq):Image Captioning;
  • N vs 1(seq2vec):Sentiment Analysis;
  • N vs M(seq2seq):Machine Translation;
  • N vs N(seq2seq):Video Classification on frame level.

在这里插入图片描述

注意 1 vs 1 是传统的MLP。

若按照内部构造进行分类则会得到:

  • RNN、Bi-RNN、…
  • LSTM、Bi-LSTM、…
  • GRU、Bi-GRU、…

四、Vanilla RNN的优缺点

优点:

  • 可以处理不定长的序列;
  • 计算时会考虑历史信息;
  • 权重沿时间方向上是共享的;
  • 模型大小不会随着输入大小增加而改变。

缺点:

  • 计算效率低;
  • 梯度会消失/爆炸(后续将知道,避免梯度爆炸可采用梯度裁剪,避免梯度消失可换用其他的RNN结构,如LSTM);
  • 无法处理长序列(即不具备长记忆性);
  • 无法利用未来的输入(Bi-RNN可解决)。

五、Bidirectional RNN

许多时候,我们要输出的

  1. y
  2. (
  3. t
  4. )
  5. y^{(t)}
  6. y(t) 可能依赖于整个序列,因此需要使用双向RNNBRNN)。BRNN结合了时间上从序列起点开始移动的RNN和从序列末尾开始移动的RNN。两个RNN互相独立不共享权重:

在这里插入图片描述
相应的计算方式变为:

  1. h
  2. (
  3. t
  4. )
  5. =
  6. tanh
  7. (
  8. W
  9. 1
  10. h
  11. (
  12. t
  13. 1
  14. )
  15. +
  16. U
  17. 1
  18. x
  19. (
  20. t
  21. )
  22. +
  23. b
  24. 1
  25. )
  26. g
  27. (
  28. t
  29. )
  30. =
  31. tanh
  32. (
  33. W
  34. 2
  35. h
  36. (
  37. t
  38. 1
  39. )
  40. +
  41. U
  42. 2
  43. x
  44. (
  45. t
  46. )
  47. +
  48. b
  49. 2
  50. )
  51. o
  52. (
  53. t
  54. )
  55. =
  56. V
  57. (
  58. h
  59. (
  60. t
  61. )
  62. ;
  63. g
  64. (
  65. t
  66. )
  67. )
  68. +
  69. c
  70. y
  71. ^
  72. (
  73. t
  74. )
  75. =
  76. softmax
  77. (
  78. o
  79. (
  80. t
  81. )
  82. )
  83. \begin{aligned} &h^{(t)}=\tanh(W_1h^{(t-1)}+U_1x^{(t)}+b_1) \\ &g^{(t)}=\tanh(W_2h^{(t-1)}+U_2x^{(t)}+b_2) \\ &o^{(t)}=V(h^{(t)};g^{(t)})+c \\ &\hat{y}^{(t)}=\text{softmax}(o^{(t)}) \\ \end{aligned}
  84. h(t)=tanh(W1h(t1)+U1x(t)+b1​)g(t)=tanh(W2h(t1)+U2x(t)+b2​)o(t)=V(h(t);g(t))+cy^​(t)=softmax(o(t))​

其中

  1. (
  2. h
  3. (
  4. t
  5. )
  6. ;
  7. g
  8. (
  9. t
  10. )
  11. )
  12. (h^{(t)};g^{(t)})
  13. (h(t);g(t)) 代表将两个列向量
  14. h
  15. (
  16. t
  17. )
  18. h^{(t)}
  19. h(t)
  20. g
  21. (
  22. t
  23. )
  24. g^{(t)}
  25. g(t) 进行纵向连接。

事实上,若将

  1. V
  2. V
  3. V 按列分块,则上述的第三个等式还可写成:
  4. o
  5. (
  6. t
  7. )
  8. =
  9. V
  10. (
  11. h
  12. (
  13. t
  14. )
  15. ;
  16. g
  17. (
  18. t
  19. )
  20. )
  21. +
  22. c
  23. =
  24. (
  25. V
  26. 1
  27. ,
  28. V
  29. 2
  30. )
  31. (
  32. h
  33. (
  34. t
  35. )
  36. g
  37. (
  38. t
  39. )
  40. )
  41. +
  42. c
  43. =
  44. V
  45. 1
  46. h
  47. (
  48. t
  49. )
  50. +
  51. V
  52. 2
  53. g
  54. (
  55. t
  56. )
  57. +
  58. c
  59. o^{(t)}=V(h^{(t)};g^{(t)})+c= (V_1,V_2) \begin{pmatrix} h^{(t)} \\ g^{(t)} \end{pmatrix}+c=V_1h^{(t)}+V_2g^{(t)}+c
  60. o(t)=V(h(t);g(t))+c=(V1​,V2​)(h(t)g(t)​)+c=V1h(t)+V2g(t)+c

训练 BRNN 的过程实际就是在学习

  1. U
  2. 1
  3. ,
  4. U
  5. 2
  6. ,
  7. V
  8. ,
  9. W
  10. 1
  11. ,
  12. W
  13. 2
  14. ,
  15. b
  16. 1
  17. ,
  18. b
  19. 2
  20. ,
  21. c
  22. U_1,U_2,V,W_1,W_2,b_1,b_2,c
  23. U1​,U2​,V,W1​,W2​,b1​,b2​,c 这些参数。

六、Stacked RNN

堆叠RNN又称多层RNN或深度RNN,即由多个隐藏层组成。以双隐层单向RNN为例,其结构如下:

在这里插入图片描述

相应的计算过程如下:

  1. h
  2. (
  3. t
  4. )
  5. =
  6. tanh
  7. (
  8. W
  9. h
  10. h
  11. h
  12. (
  13. t
  14. 1
  15. )
  16. +
  17. W
  18. x
  19. h
  20. x
  21. (
  22. t
  23. )
  24. +
  25. b
  26. h
  27. )
  28. z
  29. (
  30. t
  31. )
  32. =
  33. tanh
  34. (
  35. W
  36. z
  37. z
  38. z
  39. (
  40. t
  41. 1
  42. )
  43. +
  44. W
  45. h
  46. z
  47. h
  48. (
  49. t
  50. )
  51. +
  52. b
  53. z
  54. )
  55. o
  56. (
  57. t
  58. )
  59. =
  60. W
  61. z
  62. o
  63. z
  64. (
  65. t
  66. )
  67. +
  68. b
  69. o
  70. y
  71. ^
  72. (
  73. t
  74. )
  75. =
  76. softmax
  77. (
  78. o
  79. (
  80. t
  81. )
  82. )
  83. \begin{aligned} &h^{(t)}=\tanh(W_{hh}h^{(t-1)}+W_{xh}x^{(t)}+b_h) \\ &z^{(t)}=\tanh(W_{zz}z^{(t-1)}+W_{hz}h^{(t)}+b_z) \\ &o^{(t)}=W_{zo}z^{(t)}+b_o \\ &\hat{y}^{(t)}=\text{softmax}(o^{(t)}) \\ \end{aligned}
  84. h(t)=tanh(Whhh(t1)+Wxhx(t)+bh​)z(t)=tanh(Wzzz(t1)+Whzh(t)+bz​)o(t)=Wzoz(t)+boy^​(t)=softmax(o(t))​

本文转载自: https://blog.csdn.net/raelum/article/details/124940953
版权归原作者 raelum 所有, 如有侵权,请联系我们删除。

“自然语言处理系列(一)——RNN基础”的评论:

还没有评论