0


LSTM模型计算详解

LSTM

写在前面

本文记录笔者在学习LSTM时的记录,相信读者已经在网上看过许多的LSTM博客与视频,与其他博客不同的是,本文会从数学公式的角度,剖析LSTM模型中各个部分的模型输入输出等维度信息,帮助初学者在公式层面理解LSTM模型,并且给出了相关计算的例子代入股票预测场景,并给出参考代码。

模型结构

LSTM的模型结构如下图所示。它由若干个重复的LSTM单元组成,每个单元内部包含遗忘门、输入门和输出门,以及当前时刻的单元状态和输出状态。

LSTM模型结构图

模型输入

LSTM模型,通常是处理一个序列(比如文本序列或时间序列)

  1. X
  2. =
  3. (
  4. x
  5. 1
  6. ,
  7. x
  8. 2
  9. ,
  10. ,
  11. x
  12. t
  13. ,
  14. )
  15. T
  16. X = (x_1,x_2,\dots,x_t,\dots)^T
  17. X=(x1​,x2​,…,xt​,…)T ,每个时间步的输入可以表示为
  18. x
  19. t
  20. x_t
  21. xt​,我们使用滑动窗口将序列分为若干个窗口大小为
  22. L
  23. L
  24. L的窗口,步长为
  25. s
  26. t
  27. e
  28. p
  29. step
  30. step,当数据划分到最后,若不足为
  31. L
  32. L
  33. L不能构成窗口时,缺少的数据使用pad填充,通常为0填充或使用最近数据填充。例如,假设我们有
  34. 29
  35. 29
  36. 29个时间步骤的输入,即
  37. x
  38. =
  39. (
  40. x
  41. 0
  42. ,
  43. x
  44. 1
  45. ,
  46. ,
  47. x
  48. 28
  49. )
  50. T
  51. \vec{x} = (x_0,x_1,\dots,x_{28})^T
  52. x=(x0​,x1​,…,x28​)T,且假设窗口大小为
  53. 10
  54. 10
  55. 10,步长
  56. s
  57. t
  58. e
  59. p
  60. step
  61. step也为
  62. 10
  63. 10
  64. 10我们将数据分成三个窗口,即分为
  65. x
  66. 1
  67. =
  68. (
  69. x
  70. 0
  71. ,
  72. x
  73. 1
  74. ,
  75. ,
  76. x
  77. 9
  78. )
  79. T
  80. \vec{x_1} = (x_0,x_1,\dots,x_{9})^T
  81. x1​​=(x0​,x1​,…,x9​)T
  82. x
  83. 2
  84. =
  85. (
  86. x
  87. 10
  88. ,
  89. x
  90. 11
  91. ,
  92. ,
  93. x
  94. 19
  95. )
  96. T
  97. \vec{x_2} = (x_{10},x_{11},\dots,x_{19})^T
  98. x2​​=(x10​,x11​,…,x19​)T
  99. x
  100. 3
  101. =
  102. (
  103. x
  104. 20
  105. ,
  106. x
  107. 21
  108. ,
  109. ,
  110. x
  111. 28
  112. ,
  113. x
  114. 29
  115. )
  116. T
  117. \vec{x_3} = (x_{20},x_{21},\dots,x_{28},x_{29})^T
  118. x3​​=(x20​,x21​,…,x28​,x29​)T

由于

  1. x
  2. 29
  3. x_{29}
  4. x29​的值不存在,我们将其值设为
  5. 0
  6. 0
  7. 0或者
  8. x
  9. 28
  10. x_{28}
  11. x28​的值,即
  12. x
  13. 3
  14. =
  15. (
  16. x
  17. 20
  18. ,
  19. x
  20. 21
  21. ,
  22. ,
  23. x
  24. 28
  25. ,
  26. 0
  27. )
  28. T
  29. \vec{x_3} = (x_{20},x_{21},\dots,x_{28}, 0)^T
  30. x3​​=(x20​,x21​,…,x28​,0)T或者
  31. x
  32. 3
  33. =
  34. (
  35. x
  36. 20
  37. ,
  38. x
  39. 21
  40. ,
  41. ,
  42. x
  43. 28
  44. ,
  45. x
  46. 28
  47. )
  48. T
  49. \vec{x_3} = (x_{20},x_{21},\dots,x_{28},x_{28})^T
  50. x3​​=(x20​,x21​,…,x28​,x28​)T

当步长

  1. s
  2. t
  3. e
  4. p
  5. step
  6. step
  7. 1
  8. 1
  9. 1时,通常不会出现上面的情况,这也是我们使用的最多的一种滑动窗口划分方案。

例如,对于一个时序序列

  1. X
  2. =
  3. {
  4. x
  5. 1
  6. ,
  7. x
  8. 2
  9. ,
  10. ,
  11. x
  12. 10
  13. }
  14. X = \{x_1, x_2, \ldots, x_{10}\}
  15. X={x1​,x2​,…,x10​},窗口大小
  16. L
  17. =
  18. 3
  19. L = 3
  20. L=3,滑动步长
  21. s
  22. t
  23. e
  24. p
  25. =
  26. 1
  27. step = 1
  28. step=1,滑动窗口划分结果为:
  29. x
  30. 1
  31. =
  32. (
  33. x
  34. 1
  35. ,
  36. x
  37. 2
  38. ,
  39. x
  40. 3
  41. )
  42. x
  43. 2
  44. =
  45. (
  46. x
  47. 2
  48. ,
  49. x
  50. 3
  51. ,
  52. x
  53. 4
  54. )
  55. x
  56. 3
  57. =
  58. (
  59. x
  60. 3
  61. ,
  62. x
  63. 4
  64. ,
  65. x
  66. 5
  67. )
  68. x
  69. 4
  70. =
  71. (
  72. x
  73. 4
  74. ,
  75. x
  76. 5
  77. ,
  78. x
  79. 6
  80. )
  81. x
  82. 5
  83. =
  84. (
  85. x
  86. 5
  87. ,
  88. x
  89. 6
  90. ,
  91. x
  92. 7
  93. )
  94. x
  95. 6
  96. =
  97. (
  98. x
  99. 6
  100. ,
  101. x
  102. 7
  103. ,
  104. x
  105. 8
  106. )
  107. x
  108. 7
  109. =
  110. (
  111. x
  112. 7
  113. ,
  114. x
  115. 8
  116. ,
  117. x
  118. 9
  119. )
  120. x
  121. 8
  122. =
  123. (
  124. x
  125. 8
  126. ,
  127. x
  128. 9
  129. ,
  130. x
  131. 10
  132. )
  133. \begin{aligned} \vec{x_1} & = (x_1, x_2, x_3) \\ \vec{x_2} & = (x_2, x_3, x_4) \\ \vec{x_3} & = (x_3, x_4, x_5) \\ \vec{x_4} & = (x_4, x_5, x_6) \\ \vec{x_5} & = (x_5, x_6, x_7) \\ \vec{x_6} & = (x_6, x_7, x_8) \\ \vec{x_7} & = (x_7, x_8, x_9) \\ \vec{x_8} & = (x_8, x_9, x_{10}) \end{aligned}
  134. x1​​x2​​x3​​x4​​x5​​x6​​x7​​x8​​​=(x1​,x2​,x3​)=(x2​,x3​,x4​)=(x3​,x4​,x5​)=(x4​,x5​,x6​)=(x5​,x6​,x7​)=(x6​,x7​,x8​)=(x7​,x8​,x9​)=(x8​,x9​,x10​)​

LSTM 单元的输入包含当前时刻的输入

  1. x
  2. t
  3. \vec{x_t}
  4. xt​​、上一时刻的输出状态
  5. h
  6. t
  7. 1
  8. h_{t-1}
  9. ht1​以及上一时刻的单元状态
  10. c
  11. t
  12. 1
  13. c_{t-1}
  14. ct1​。在进行运算第一层LSTM单元时,我们会手动初始化
  15. h
  16. 0
  17. h_0
  18. h0​、
  19. c
  20. 0
  21. c_0
  22. c0​,而在后面的LSTM的单元中
  23. h
  24. t
  25. 1
  26. h_{t-1}
  27. ht1​和
  28. c
  29. t
  30. 1
  31. c_{t-1}
  32. ct1​,都可以由上一次的LSTM单元获得。
  33. x
  34. t
  35. \vec{x_t}
  36. xt​​、
  37. h
  38. t
  39. 1
  40. h_{t-1}
  41. ht1​、
  42. c
  43. t
  44. 1
  45. c_{t-1}
  46. ct1​分别代表当前时刻的输入信息、上一时刻的输出信息以及上一时刻的记忆信息。其中,
  47. x
  48. t
  49. R
  50. m
  51. ×
  52. 1
  53. \vec{x_t} \in \mathbb{R}^{m \times 1}
  54. xt​​∈Rm×1
  55. m
  56. m
  57. m是输入序列处理后的窗口大小(长度),
  58. h
  59. t
  60. 1
  61. h_{t-1}
  62. ht1​上一时刻的输出状态,形状为
  63. h
  64. t
  65. 1
  66. R
  67. d
  68. ×
  69. 1
  70. h_{t-1} \in \mathbb{R}^{d \times 1}
  71. ht1​∈Rd×1
  72. d
  73. d
  74. dLSTM单元的隐藏状态大小,
  75. c
  76. t
  77. 1
  78. c_{t-1}
  79. ct1​是上一时刻的单元状态,形状为
  80. c
  81. t
  82. 1
  83. R
  84. d
  85. ×
  86. 1
  87. c_{t-1} \in \mathbb{R}^{d \times 1}
  88. ct1​∈Rd×1,与
  89. h
  90. t
  91. 1
  92. h_{t-1}
  93. ht1​具有相同的形状。

我们通常会把

  1. h
  2. t
  3. 1
  4. h_{t-1}
  5. ht1​和
  6. x
  7. t
  8. \vec{x_t}
  9. xt​​拼在一起形成更长的向量
  10. y
  11. t
  12. \vec{y_t}
  13. yt​​,我们通常竖着拼,即
  14. y
  15. t
  16. R
  17. (
  18. d
  19. +
  20. m
  21. )
  22. ×
  23. 1
  24. \vec{y_t} \in \mathbb{R}^{(d + m) \times 1}
  25. yt​​∈R(d+m1 ,如公式下所示,然后
  26. y
  27. t
  28. \vec{y_t}
  29. yt​​会传入各个门。当采用多批次时,
  30. y
  31. t
  32. R
  33. (
  34. d
  35. +
  36. m
  37. )
  38. ×
  39. n
  40. \vec{y_t} \in \mathbb{R}^{(d + m) \times n}
  41. yt​​∈R(d+mn
  42. y
  43. t
  44. =
  45. [
  46. h
  47. t
  48. 1
  49. ;
  50. x
  51. t
  52. ]
  53. =
  54. [
  55. h
  56. t
  57. 1
  58. x
  59. t
  60. ]
  61. \vec{y_t} = [h_{t-1}; \vec{x_t}] = \left[{\begin{matrix} h_{t-1} \\ \vec{x_t} \end{matrix}}\right]
  62. yt​​=[ht1​;xt​​]=[ht1xt​​​]

遗忘门

遗忘门的输入为我们在模型输入中处理得到的

  1. X
  2. t
  3. X_t'
  4. Xt′​。我们将
  5. X
  6. t
  7. X_t'
  8. Xt′​与遗忘门中的权重矩阵
  9. W
  10. f
  11. W_f
  12. Wf​相乘再加上置偏值
  13. b
  14. f
  15. b_f
  16. bf​,得到结果
  17. M
  18. f
  19. M_f
  20. Mf​。然后对
  21. M
  22. f
  23. M_f
  24. Mf​取Sigmoid,得到遗忘门的输出
  25. f
  26. t
  27. f_t
  28. ft​,其形状与单元状态
  29. c
  30. t
  31. c_t
  32. ct​相同,即
  33. f
  34. t
  35. R
  36. d
  37. ×
  38. 1
  39. f_t \in \mathbb{R}^{d \times 1}
  40. ft​∈Rd×1,表示遗忘的程度。具体的计算公式如(\ref{LSTME02})所示。
  41. M
  42. f
  43. =
  44. W
  45. f
  46. y
  47. t
  48. +
  49. b
  50. f
  51. M_f = W_f\vec{y_t} + b_f
  52. Mf​=Wfyt​​+bf
  53. f
  54. t
  55. =
  56. σ
  57. (
  58. M
  59. f
  60. )
  61. =
  62. 1
  63. 1
  64. +
  65. e
  66. (
  67. W
  68. f
  69. y
  70. t
  71. +
  72. b
  73. f
  74. )
  75. f_t = \sigma(M_f) = \frac{1}{1 + e^{-(W_f\vec{y_t} + b_f)}}
  76. ft​=σ(Mf​)=1+e−(Wfyt​​+bf​)1

其中,

  1. y
  2. t
  3. R
  4. (
  5. d
  6. +
  7. m
  8. )
  9. ×
  10. 1
  11. \vec{y_t} \in \mathbb{R}^{(d + m) \times 1}
  12. yt​​∈R(d+m1
  13. W
  14. f
  15. R
  16. d
  17. ×
  18. (
  19. d
  20. +
  21. m
  22. )
  23. W_f \in \mathbb{R}^{d \times (d + m)}
  24. Wf​∈Rd×(d+m),
  25. b
  26. f
  27. R
  28. d
  29. ×
  30. 1
  31. b_f \in \mathbb{R}^{d \times 1}
  32. bf​∈Rd×1
  33. f
  34. t
  35. R
  36. d
  37. ×
  38. 1
  39. f_t \in \mathbb{R}^{d \times 1}
  40. ft​∈Rd×1

在LSTM的许多门中,都使用Sigmoid函数,Sigmoid函数的绝大部分的值的取值范围为

  1. (
  2. 0
  3. ,
  4. 1
  5. )
  6. (0, 1)
  7. (0,1),这可以很有效的表示在Sigmoid函数的输入中哪些数据需要记忆,哪些数据需要遗忘的过程。当Sigmoid函数只越接近
  8. 0
  9. 0
  10. 0时表示遗忘,当接近
  11. 1
  12. 1
  13. 1时表示需要记忆。

输入门

输入门的输入为我们在模型输入中处理得到的

  1. y
  2. t
  3. \vec{y_t}
  4. yt​​,且
  5. y
  6. t
  7. R
  8. (
  9. d
  10. +
  11. m
  12. )
  13. ×
  14. 1
  15. \vec{y_t} \in \mathbb{R}^{(d + m) \times 1 }
  16. yt​​∈R(d+m1。我们将
  17. y
  18. t
  19. \vec{y_t}
  20. yt​​与输入门中的权重矩阵
  21. W
  22. i
  23. W_i
  24. Wi​相乘再加上置偏值
  25. b
  26. i
  27. b_i
  28. bi​,得到结果
  29. M
  30. i
  31. M_i
  32. Mi​,然后对
  33. M
  34. i
  35. M_i
  36. Mi​取Sigmoid,得到输入门的输出
  37. i
  38. t
  39. i_t
  40. it​,表示输入的重要程度。具体的计算公式如下所示。
  41. M
  42. i
  43. =
  44. W
  45. i
  46. y
  47. t
  48. +
  49. b
  50. i
  51. M_i = W_i\vec{y_t} + b_i
  52. Mi​=Wiyt​​+bi
  53. i
  54. t
  55. =
  56. σ
  57. (
  58. M
  59. i
  60. )
  61. =
  62. 1
  63. 1
  64. +
  65. e
  66. (
  67. W
  68. i
  69. y
  70. t
  71. +
  72. b
  73. i
  74. )
  75. i_t = \sigma(M_i) = \frac{1}{1 + e^{-(W_i\vec{y_t} + b_i)}}
  76. it​=σ(Mi​)=1+e−(Wiyt​​+bi​)1

其中,

  1. y
  2. t
  3. R
  4. (
  5. d
  6. +
  7. m
  8. )
  9. ×
  10. n
  11. \vec{y_t} \in \mathbb{R}^{(d + m) \times n}
  12. yt​​∈R(d+mn
  13. W
  14. i
  15. R
  16. d
  17. ×
  18. (
  19. d
  20. +
  21. m
  22. )
  23. W_i \in \mathbb{R}^{d \times (d + m)}
  24. Wi​∈Rd×(d+m),
  25. b
  26. i
  27. R
  28. d
  29. ×
  30. 1
  31. b_i \in \mathbb{R}^{d \times 1}
  32. bi​∈Rd×1
  33. i
  34. t
  35. R
  36. d
  37. ×
  38. 1
  39. i_t \in \mathbb{R}^{d \times 1}
  40. it​∈Rd×1

输出门

输出门的输入为我们在模型输入中处理得到的

  1. y
  2. t
  3. \vec{y_t}
  4. yt​​,且
  5. y
  6. t
  7. R
  8. (
  9. d
  10. +
  11. m
  12. )
  13. ×
  14. 1
  15. \vec{y_t} \in \mathbb{R}^{(d + m) \times 1 }
  16. yt​​∈R(d+m1。我们将
  17. y
  18. t
  19. \vec{y_t}
  20. yt​​与输出门中的权重矩阵
  21. W
  22. o
  23. W_o
  24. Wo​相乘再加上置偏值
  25. b
  26. o
  27. b_o
  28. bo​,得到结果
  29. M
  30. o
  31. M_o
  32. Mo​,然后对
  33. M
  34. o
  35. M_o
  36. Mo​取Sigmoid,得到输出门的输出
  37. o
  38. t
  39. o_t
  40. ot​,具体的计算公式如下所示。
  41. M
  42. o
  43. =
  44. W
  45. o
  46. y
  47. t
  48. +
  49. b
  50. o
  51. M_o = W_o\vec{y_t} + b_o
  52. Mo​=Woyt​​+bo
  53. o
  54. t
  55. =
  56. σ
  57. (
  58. M
  59. o
  60. )
  61. =
  62. 1
  63. 1
  64. +
  65. e
  66. (
  67. W
  68. o
  69. y
  70. t
  71. +
  72. b
  73. o
  74. )
  75. o_t = \sigma(M_o) = \frac{1}{1 + e^{-(W_o\vec{y_t} + b_o)}}
  76. ot​=σ(Mo​)=1+e−(Woyt​​+bo​)1

其中,

  1. y
  2. t
  3. R
  4. (
  5. d
  6. +
  7. m
  8. )
  9. ×
  10. 1
  11. \vec{y_t} \in \mathbb{R}^{(d + m) \times 1}
  12. yt​​∈R(d+m1
  13. W
  14. o
  15. R
  16. d
  17. ×
  18. (
  19. d
  20. +
  21. m
  22. )
  23. W_o \in \mathbb{R}^{d \times (d + m)}
  24. Wo​∈Rd×(d+m),
  25. b
  26. o
  27. R
  28. d
  29. ×
  30. 1
  31. b_o \in \mathbb{R}^{d \times 1}
  32. bo​∈Rd×1
  33. o
  34. t
  35. R
  36. d
  37. ×
  38. 1
  39. o_t \in \mathbb{R}^{d \times 1}
  40. ot​∈Rd×1

当前输入单元状态

在计算

  1. c
  2. t
  3. c_t
  4. ct​之前,我们需要引入当前输入单元状态,并计算
  5. c
  6. t
  7. ~
  8. \tilde{c_t}
  9. ct​~​的值。
  10. c
  11. t
  12. ~
  13. \tilde{c_t}
  14. ct​~​是当前输入的单元状态,表示当前输入要保留多少内容到记忆中。我们将
  15. y
  16. t
  17. \vec{y_t}
  18. yt​​与当前时刻状态单元的权重矩阵
  19. W
  20. c
  21. W_c
  22. Wc​相乘再加上置偏值
  23. b
  24. c
  25. b_c
  26. bc​,得到结果
  27. M
  28. c
  29. M_c
  30. Mc​,然后对
  31. M
  32. c
  33. M_c
  34. Mc​取tanh,得到的输出
  35. c
  36. t
  37. ~
  38. \tilde{c_t}
  39. ct​~​。
  40. c
  41. t
  42. ~
  43. \tilde{c_t}
  44. ct​~​的计算如公式下所示。
  45. M
  46. c
  47. =
  48. W
  49. c
  50. y
  51. t
  52. +
  53. b
  54. c
  55. M_c = W_c\vec{y_t} + b_c
  56. Mc​=Wcyt​​+bc
  57. c
  58. t
  59. ~
  60. =
  61. tanh
  62. (
  63. M
  64. c
  65. )
  66. =
  67. e
  68. M
  69. c
  70. e
  71. M
  72. c
  73. e
  74. M
  75. c
  76. +
  77. e
  78. M
  79. c
  80. =
  81. (
  82. e
  83. W
  84. c
  85. y
  86. t
  87. +
  88. b
  89. c
  90. )
  91. e
  92. (
  93. W
  94. c
  95. y
  96. t
  97. +
  98. b
  99. c
  100. )
  101. (
  102. e
  103. W
  104. c
  105. y
  106. t
  107. +
  108. b
  109. c
  110. )
  111. +
  112. e
  113. (
  114. W
  115. c
  116. y
  117. t
  118. +
  119. b
  120. c
  121. )
  122. \tilde{c_t} = \text{tanh}(M_c) = \frac{e^{M_c}-e^{-M_c}}{e^{M_c}+e^{-M_c}} = \frac{(e^{W_c\vec{y_t} + b_c)}-e^{-(W_c\vec{y_t} + b_c)}}{(e^{W_c\vec{y_t} + b_c)}+e^{-(W_c\vec{y_t} + b_c)}}
  123. ct​~​=tanh(Mc​)=eMc​+eMceMc​−eMc​​=(eWcyt​​+bc​)+e−(Wcyt​​+bc​)(eWcyt​​+bc​)−e−(Wcyt​​+bc​)​

其中,

  1. y
  2. t
  3. R
  4. (
  5. d
  6. +
  7. m
  8. )
  9. ×
  10. 1
  11. \vec{y_t} \in \mathbb{R}^{(d + m) \times 1}
  12. yt​​∈R(d+m1
  13. W
  14. c
  15. R
  16. d
  17. ×
  18. (
  19. d
  20. +
  21. m
  22. )
  23. W_c \in \mathbb{R}^{d \times (d + m)}
  24. Wc​∈Rd×(d+m),
  25. b
  26. c
  27. R
  28. d
  29. ×
  30. 1
  31. b_c \in \mathbb{R}^{d \times 1}
  32. bc​∈Rd×1
  33. c
  34. t
  35. ~
  36. R
  37. d
  38. ×
  39. 1
  40. \tilde{c_t} \in \mathbb{R}^{d \times 1}
  41. ct​~​∈Rd×1

当前输入单元状态中,使用了tanh函数,tanh函数的取值范围为

  1. (
  2. 1
  3. ,
  4. 1
  5. )
  6. (-1,1)
  7. (−1,1),当函数的值接近
  8. 1
  9. -1
  10. 1时代表着当前输入信息要被修正,当但函数值接近
  11. 1
  12. 1
  13. 1时,代码当前输入信息要被加强。

当前时刻单元状态

接下来我们进行当前时刻单元状态

  1. c
  2. t
  3. c_t
  4. ct​的计算。我们使用遗忘门和输入门得到的结果
  5. f
  6. t
  7. f_t
  8. ft​、
  9. i
  10. t
  11. i_t
  12. it​和上一时刻单元状态
  13. c
  14. t
  15. 1
  16. c_{t-1}
  17. ct1​来计算当前时刻单元状态
  18. c
  19. t
  20. c_t
  21. ct​。我们分别将
  22. f
  23. t
  24. f_t
  25. ft​、
  26. c
  27. t
  28. 1
  29. c_{t-1}
  30. ct1​按元素相乘,
  31. i
  32. t
  33. i_t
  34. it​和
  35. c
  36. t
  37. ~
  38. \tilde{c_t}
  39. ct​~​按元素相乘,然后再将两者相加得到我们的当前时刻单元状态
  40. c
  41. t
  42. c_t
  43. ct​。具体计算如公式下所示。
  44. c
  45. t
  46. =
  47. f
  48. t
  49. c
  50. t
  51. 1
  52. +
  53. i
  54. t
  55. c
  56. t
  57. ~
  58. c_t = f_t \circ c_{t-1} + i_t \circ \tilde{c_t}
  59. ct​=ft​∘ct1​+it​∘ct​~​

其中,

  1. f
  2. t
  3. R
  4. d
  5. ×
  6. 1
  7. f_t \in \mathbb{R}^{d \times 1}
  8. ft​∈Rd×1时遗忘门输出,
  9. i
  10. t
  11. R
  12. d
  13. ×
  14. 1
  15. i_t \in \mathbb{R}^{d \times 1}
  16. it​∈Rd×1是输入门输出,
  17. c
  18. t
  19. ~
  20. R
  21. d
  22. ×
  23. 1
  24. \tilde{c_{t}} \in \mathbb{R}^{d \times 1}
  25. ct​~​∈Rd×1是当前输入状态单元,
  26. c
  27. t
  28. 1
  29. R
  30. d
  31. ×
  32. 1
  33. c_{t-1} \in \mathbb{R}^{d \times 1}
  34. ct1​∈Rd×1 是上一时刻状态单元,
  35. \circ
  36. ∘表示 **按元素乘**。

模型输出

模型的输出是

  1. h
  2. t
  3. h_t
  4. ht​和当前时刻的单元状态
  5. c
  6. t
  7. c_t
  8. ct​,而
  9. h
  10. t
  11. h_t
  12. ht​由当前时刻的单元状态
  13. c
  14. t
  15. c_t
  16. ct​​和输出门的输出
  17. o
  18. t
  19. o_t
  20. ot​确定。我们将当前时刻的单元状态
  21. c
  22. t
  23. c_t
  24. ct​取 tanh得到
  25. d
  26. t
  27. d_t
  28. dt​,然后将
  29. d
  30. t
  31. d_t
  32. dt
  33. o
  34. t
  35. o_t
  36. ot​按元素相乘得到最后的
  37. h
  38. t
  39. h_t
  40. ht​,计算公式如下所示。通常,
  41. h
  42. t
  43. h_t
  44. ht​会进一步传递给模型的上层或者作为最终的预测结果。
  45. d
  46. t
  47. =
  48. tanh
  49. (
  50. c
  51. t
  52. )
  53. =
  54. e
  55. c
  56. t
  57. e
  58. c
  59. t
  60. e
  61. c
  62. t
  63. +
  64. e
  65. c
  66. t
  67. d_t = \text{tanh}(c_t) = \frac{e^{c_t}-e^{-c_t}}{e^{c_t}+e^{-c_t}}
  68. dt​=tanh(ct​)=ect​+ectect​−ect​​
  69. h
  70. t
  71. =
  72. o
  73. t
  74. d
  75. t
  76. h_t = o_t \circ d_t
  77. ht​=ot​∘dt

其中

  1. h
  2. t
  3. R
  4. d
  5. ×
  6. 1
  7. h_t \in \mathbb{R}^{d \times 1}
  8. ht​∈Rd×1 为当前层隐藏状态,
  9. o
  10. t
  11. R
  12. d
  13. ×
  14. 1
  15. o_t \in \mathbb{R}^{d \times 1}
  16. ot​∈Rd×1为输出门的输出,
  17. c
  18. t
  19. R
  20. d
  21. ×
  22. 1
  23. c_t \in \mathbb{R}^{d \times 1}
  24. ct​∈Rd×1为当前时刻状态单元。

日期开盘价收盘价最高价最低价4月23日3038.61183021.97753044.94383016.51684月24日3029.40283044.82233045.63993019.12384月25日3037.92723052.89993060.26343034.64994月26日3054.97933088.63573092.43003054.9793
Table: SH000001

简单的LSTM例子

接下来我们根据上面的模型结构中的计算方法来简单计算一个LSTM的例子。

我们以取中国A股上证指数(SH000001)2024年4月23日-25日共3个交易日的数据为例,取开盘价、收盘价、最高价、最低价作为特征,具体数据如表格所示。使用LSTM模型计算预测2024年4月26日的开盘价、收盘价、最高价、最低价,损失函数使用MSE。我们取隐藏层状态

  1. d
  2. d
  3. d的大小为
  4. 4
  5. 4
  6. 4,然后进行计算,预测下一天的数据。

我们把表格数据处理成

  1. x
  2. t
  3. x_t
  4. xt​的形式,也就是把每天的
  5. 4
  6. 4
  7. 4个特征,转换成
  8. m
  9. ×
  10. 1
  11. m \times 1
  12. m×1
  13. (
  14. 4
  15. ×
  16. 1
  17. )
  18. (4 \times 1)
  19. (4×1)的向量,然后我们得到以
  20. X
  21. X
  22. X的结果。
  23. X
  24. =
  25. (
  26. x
  27. 1
  28. ,
  29. x
  30. 2
  31. ,
  32. x
  33. 3
  34. )
  35. =
  36. [
  37. 3038.6118
  38. 3029.4028
  39. 3037.9272
  40. 3021.9775
  41. 3044.8223
  42. 3052.8999
  43. 3044.9438
  44. 3045.6399
  45. 3060.2634
  46. 3016.5168
  47. 3019.1238
  48. 3034.6499
  49. ]
  50. X = (\vec{x_1}, \vec{x_2}, \vec{x_3}) = \begin{bmatrix} 3038.6118 & 3029.4028 & 3037.9272 \\ 3021.9775 & 3044.8223 & 3052.8999 \\ 3044.9438 & 3045.6399 & 3060.2634 \\ 3016.5168 & 3019.1238 & 3034.6499 \\ \end{bmatrix}
  51. X=(x1​​,x2​​,x3​​)=​3038.61183021.97753044.94383016.51683029.40283044.82233045.63993019.12383037.92723052.89993060.26343034.6499​​

由于隐藏层大小为

  1. d
  2. =
  3. 4
  4. d = 4
  5. d=4,所以
  6. h
  7. 0
  8. h_0
  9. h0​、
  10. c
  11. 0
  12. c_0
  13. c0​的维度都是
  14. 4
  15. ×
  16. 1
  17. 4 \times 1
  18. 4×1,我们将
  19. h
  20. 0
  21. h_0
  22. h0​和
  23. c
  24. 0
  25. c_0
  26. c0​进行初始化为
  27. 0
  28. \vec{0}
  29. 0向量,即
  30. h
  31. 0
  32. =
  33. [
  34. 0
  35. ,
  36. 0
  37. ,
  38. 0
  39. ,
  40. 0
  41. ]
  42. T
  43. ,
  44. c
  45. 0
  46. =
  47. [
  48. 0
  49. ,
  50. 0
  51. ,
  52. 0
  53. ,
  54. 0
  55. ]
  56. T
  57. h_0 = [0, 0, 0, 0]^T, c_0 = [0, 0, 0, 0]^T
  58. h0​=[0,0,0,0]T,c0​=[0,0,0,0]T

随后我们初始化

  1. W
  2. f
  3. W_f
  4. Wf​、
  5. W
  6. i
  7. W_i
  8. Wi​、
  9. W
  10. c
  11. W_c
  12. Wc​、
  13. W
  14. o
  15. W_o
  16. Wo​(维度为
  17. d
  18. ×
  19. (
  20. d
  21. +
  22. m
  23. )
  24. d \times (d + m)
  25. d×(d+m),即
  26. 4
  27. ×
  28. 8
  29. 4 \times 8
  30. 4×8以及
  31. b
  32. f
  33. b_f
  34. bf​、
  35. b
  36. i
  37. b_i
  38. bi​、
  39. b
  40. c
  41. b_c
  42. bc​、
  43. b
  44. o
  45. b_o
  46. bo​,
  47. W
  48. W
  49. W的元素值
  50. [
  51. 0.0001
  52. ,
  53. 0.0001
  54. ]
  55. \in [-0.0001, 0.0001]
  56. ∈[−0.0001,0.0001],W是随机矩阵,如下所示。
  57. W
  58. f
  59. =
  60. [
  61. 0.0005
  62. 0.0010
  63. 0.0010
  64. 0.0004
  65. 0.0008
  66. 0.0006
  67. 0.0006
  68. 0.0007
  69. 0.0004
  70. 0.0009
  71. 0.0006
  72. 0.0009
  73. 0.0001
  74. 0.0004
  75. 0.0009
  76. 0.0003
  77. 0.0005
  78. 0.0006
  79. 0.0007
  80. 0.0003
  81. 0.0003
  82. 0.0001
  83. 0.0004
  84. 0.0006
  85. 0.0007
  86. 0.0008
  87. 0.0007
  88. 0.0006
  89. 0.0005
  90. 0.0003
  91. 0.0010
  92. 0.0002
  93. ]
  94. W_f = \begin{bmatrix} -0.0005 & -0.0010 & -0.0010 & -0.0004 & -0.0008 & -0.0006 & -0.0006 & -0.0007 \\ 0.0004 & -0.0009 & -0.0006 & 0.0009 & 0.0001 & 0.0004 & 0.0009 & 0.0003 \\ -0.0005 & -0.0006 & 0.0007 & -0.0003 & -0.0003 & 0.0001 & 0.0004 & 0.0006 \\ -0.0007 & -0.0008 & 0.0007 & -0.0006 & 0.0005 & -0.0003 & -0.0010 & -0.0002 \\ \end{bmatrix}
  95. Wf​=​−0.00050.00040.00050.0007​−0.00100.00090.00060.0008​−0.00100.00060.00070.0007​−0.00040.00090.00030.0006​−0.00080.00010.00030.0005​−0.00060.00040.00010.0003​−0.00060.00090.00040.0010​−0.00070.00030.00060.0002​​
  96. W
  97. i
  98. =
  99. [
  100. 0.0006
  101. 0.0001
  102. 0.0003
  103. 0.0002
  104. 0.0008
  105. 0.0000
  106. 0.0003
  107. 0.0003
  108. 0.0007
  109. 0.0002
  110. 0.0006
  111. 0.0001
  112. 0.0009
  113. 0.0005
  114. 0.0007
  115. 0.0005
  116. 0.0008
  117. 0.0004
  118. 0.0007
  119. 0.0008
  120. 0.0008
  121. 0.0010
  122. 0.0006
  123. 0.0009
  124. 0.0005
  125. 0.0010
  126. 0.0006
  127. 0.0002
  128. 0.0002
  129. 0.0006
  130. 0.0007
  131. 0.0002
  132. ]
  133. W_i = \begin{bmatrix} -0.0006 & -0.0001 & -0.0003 & 0.0002 & 0.0008 & 0.0000 & -0.0003 & -0.0003 \\ 0.0007 & -0.0002 & 0.0006 & 0.0001 & -0.0009 & -0.0005 & -0.0007 & -0.0005 \\ -0.0008 & 0.0004 & 0.0007 & -0.0008 & -0.0008 & 0.0010 & -0.0006 & -0.0009 \\ -0.0005 & 0.0010 & -0.0006 & -0.0002 & -0.0002 & 0.0006 & -0.0007 & 0.0002 \\ \end{bmatrix}
  134. Wi​=​−0.00060.00070.00080.0005​−0.00010.00020.00040.0010​−0.00030.00060.00070.00060.00020.00010.00080.00020.00080.00090.00080.00020.00000.00050.00100.0006​−0.00030.00070.00060.0007​−0.00030.00050.00090.0002​​
  135. W
  136. c
  137. =
  138. [
  139. 0.0001
  140. 0.0004
  141. 0.0000
  142. 0.0006
  143. 0.0006
  144. 0.0002
  145. 0.0003
  146. 0.0005
  147. 0.0002
  148. 0.0006
  149. 0.0005
  150. 0.0009
  151. 0.0002
  152. 0.0008
  153. 0.0003
  154. 0.0009
  155. 0.0002
  156. 0.0004
  157. 0.0000
  158. 0.0009
  159. 0.0003
  160. 0.0003
  161. 0.0006
  162. 0.0008
  163. 0.0007
  164. 0.0008
  165. 0.0009
  166. 0.0007
  167. 0.0002
  168. 0.0010
  169. 0.0006
  170. 0.0003
  171. ]
  172. W_c = \begin{bmatrix} 0.0001 & 0.0004 & 0.0000 & -0.0006 & -0.0006 & -0.0002 & 0.0003 & 0.0005 \\ -0.0002 & -0.0006 & 0.0005 & -0.0009 & 0.0002 & -0.0008 & -0.0003 & -0.0009 \\ 0.0002 & 0.0004 & 0.0000 & 0.0009 & 0.0003 & 0.0003 & 0.0006 & -0.0008 \\ -0.0007 & -0.0008 & 0.0009 & -0.0007 & 0.0002 & -0.0010 & -0.0006 & -0.0003 \\ \end{bmatrix}
  173. Wc​=​0.00010.00020.00020.00070.00040.00060.00040.00080.00000.00050.00000.0009​−0.00060.00090.00090.0007​−0.00060.00020.00030.0002​−0.00020.00080.00030.00100.00030.00030.00060.00060.00050.00090.00080.0003​​
  174. W
  175. o
  176. =
  177. [
  178. 0.0009
  179. 0.0005
  180. 0.0000
  181. 0.0001
  182. 0.0001
  183. 0.0004
  184. 0.0005
  185. 0.0007
  186. 0.0009
  187. 0.0005
  188. 0.0008
  189. 0.0009
  190. 0.0001
  191. 0.0004
  192. 0.0002
  193. 0.0004
  194. 0.0005
  195. 0.0004
  196. 0.0007
  197. 0.0008
  198. 0.0006
  199. 0.0008
  200. 0.0006
  201. 0.0010
  202. 0.0002
  203. 0.0008
  204. 0.0008
  205. 0.0002
  206. 0.0008
  207. 0.0004
  208. 0.0008
  209. 0.0002
  210. ]
  211. W_o = \begin{bmatrix} -0.0009 & -0.0005 & 0.0000 & 0.0001 & -0.0001 & -0.0004 & -0.0005 & -0.0007 \\ 0.0009 & -0.0005 & 0.0008 & -0.0009 & 0.0001 & 0.0004 & -0.0002 & 0.0004 \\ -0.0005 & -0.0004 & 0.0007 & -0.0008 & -0.0006 & 0.0008 & 0.0006 & 0.0010 \\ -0.0002 & 0.0008 & 0.0008 & -0.0002 & 0.0008 & -0.0004 & 0.0008 & -0.0002 \\ \end{bmatrix}
  212. Wo​=​−0.00090.00090.00050.0002​−0.00050.00050.00040.00080.00000.00080.00070.00080.00010.00090.00080.0002​−0.00010.00010.00060.0008​−0.00040.00040.00080.0004​−0.00050.00020.00060.0008​−0.00070.00040.00100.0002​​
  213. b
  214. b
  215. b全部初始化为单位列向量即
  216. b
  217. f
  218. =
  219. b
  220. i
  221. =
  222. b
  223. c
  224. =
  225. b
  226. o
  227. =
  228. [
  229. 1
  230. 1
  231. 1
  232. 1
  233. ]
  234. T
  235. b_f = b_i = b_c = b_o = \begin{bmatrix} 1 \\ 1 \\ 1 \\ 1 \end{bmatrix}^T
  236. bf​=bi​=bc​=bo​=​1111​​T

然后我们将

  1. h
  2. 0
  3. h_0
  4. h0​与
  5. x
  6. 1
  7. x_1
  8. x1​拼在一起作为
  9. y
  10. 1
  11. \vec{y_1}
  12. y1​​,即
  13. y
  14. 1
  15. =
  16. [
  17. h
  18. 0
  19. ;
  20. x
  21. 1
  22. ]
  23. =
  24. [
  25. 0
  26. 0
  27. 0
  28. 0
  29. 3038.6118
  30. 3021.9775
  31. 3044.9438
  32. 3016.5168
  33. ]
  34. T
  35. \vec{y_1} = [h_0; \vec{x_1}] = \begin{bmatrix} 0 & 0 & 0 & 0 & 3038.6118 & 3021.9775 & 3044.9438 & 3016.5168 \end{bmatrix}^T
  36. y1​​=[h0​;x1​​]=[00003038.61183021.97753044.94383016.5168​]T

我们依次计算遗忘门

  1. f
  2. 1
  3. f_1
  4. f1​,输入门
  5. i
  6. 1
  7. i_1
  8. i1​,输出门
  9. o
  10. 1
  11. o_1
  12. o1​,即
  13. f
  14. 1
  15. =
  16. σ
  17. (
  18. W
  19. f
  20. y
  21. 1
  22. +
  23. b
  24. f
  25. )
  26. =
  27. [
  28. 0.0008
  29. 0.9985
  30. 0.9713
  31. 0.1164
  32. ]
  33. ,
  34. i
  35. 1
  36. =
  37. σ
  38. (
  39. W
  40. i
  41. y
  42. 1
  43. +
  44. b
  45. i
  46. )
  47. =
  48. [
  49. 0.8514
  50. 0.0010
  51. 0.0568
  52. 0.6491
  53. ]
  54. ,
  55. o
  56. 1
  57. =
  58. σ
  59. (
  60. W
  61. o
  62. y
  63. 1
  64. +
  65. b
  66. o
  67. )
  68. =
  69. [
  70. 0.0198
  71. 0.9577
  72. 0.9981
  73. 0.9842
  74. ]
  75. f_1 = \sigma(W_f\vec{y_1} + b_f) = \begin{bmatrix} 0.0008 \\ 0.9985 \\ 0.9713 \\ 0.1164 \end{bmatrix}, i_1 = \sigma(W_i\vec{y_1} + b_i) = \begin{bmatrix} 0.8514 \\ 0.0010 \\ 0.0568 \\ 0.6491 \end{bmatrix}, o_1 = \sigma(W_o\vec{y_1} + b_o) = \begin{bmatrix} 0.0198 \\ 0.9577 \\ 0.9981 \\ 0.9842 \end{bmatrix}
  76. f1​=σ(Wfy1​​+bf​)=​0.00080.99850.97130.1164​​,i1​=σ(Wiy1​​+bi​)=​0.85140.00100.05680.6491​​,o1​=σ(Woy1​​+bo​)=​0.01980.95770.99810.9842​​

随后我们进行计算当前输入单元状态

  1. c
  2. 1
  3. ~
  4. \tilde{c_1}
  5. c1​~​,即
  6. c
  7. 1
  8. ~
  9. =
  10. tanh
  11. (
  12. W
  13. c
  14. y
  15. 1
  16. +
  17. b
  18. c
  19. )
  20. =
  21. [
  22. 0.7923
  23. 0.9997
  24. 0.9805
  25. 0.9994
  26. ]
  27. T
  28. \tilde{c_1} = \text{tanh}(W_c\vec{y_1} + b_c) = \begin{bmatrix} 0.7923 & -0.9997 & 0.9805 & -0.9994 \end{bmatrix}^T
  29. c1​~​=tanh(Wcy1​​+bc​)=[0.7923​−0.99970.9805​−0.9994​]T

接着我们计算当前时刻单元状态

  1. c
  2. 1
  3. c_1
  4. c1​,即
  5. c
  6. 1
  7. =
  8. f
  9. 1
  10. c
  11. 0
  12. +
  13. i
  14. 1
  15. c
  16. 1
  17. ~
  18. =
  19. [
  20. 0.0008
  21. 0.9985
  22. 0.9713
  23. 0.1164
  24. ]
  25. [
  26. 0
  27. 0
  28. 0
  29. 0
  30. ]
  31. +
  32. [
  33. 0.8514
  34. 0.0010
  35. 0.0568
  36. 0.6491
  37. ]
  38. [
  39. 0.7923
  40. 0.9997
  41. 0.9805
  42. 0.9994
  43. ]
  44. =
  45. [
  46. 0.6746
  47. 0.001
  48. 0.0557
  49. 0.6488
  50. ]
  51. c_1 = f_1 \circ c_{0} + i_1 \circ \tilde{c_1} = \begin{bmatrix} 0.0008 \\ 0.9985 \\ 0.9713 \\ 0.1164 \end{bmatrix} \circ \begin{bmatrix} 0 \\ 0 \\ 0 \\ 0 \end{bmatrix} + \begin{bmatrix} 0.8514 \\ 0.0010 \\ 0.0568 \\ 0.6491 \end{bmatrix} \circ \begin{bmatrix} 0.7923 \\ -0.9997 \\ 0.9805 \\ -0.9994 \end{bmatrix} = \begin{bmatrix} 0.6746 \\ -0.001 \\ 0.0557 \\ -0.6488 \end{bmatrix}
  52. c1​=f1​∘c0​+i1​∘c1​~​=​0.00080.99850.97130.1164​​∘​0000​​+​0.85140.00100.05680.6491​​∘​0.79230.99970.98050.9994​​=​0.67460.0010.05570.6488​​

最后我们计算当前层隐藏层输出

  1. h
  2. 1
  3. h_1
  4. h1​,即
  5. h
  6. 1
  7. =
  8. o
  9. 1
  10. d
  11. 1
  12. =
  13. o
  14. 1
  15. tanh
  16. (
  17. c
  18. 1
  19. )
  20. =
  21. [
  22. 0.0116
  23. 0.001
  24. 0.0556
  25. 0.5618
  26. ]
  27. T
  28. h_1 = o_1 \circ d_1 = o_1 \circ \text{tanh}(c_1) = \begin{bmatrix} 0.0116 & -0.001 & 0.0556 & -0.5618 \end{bmatrix}^T
  29. h1​=o1​∘d1​=o1​∘tanh(c1​)=[0.0116​−0.0010.0556​−0.5618​]T

这样我们就完成了一次LSTM单元的正向传播计算,我们得到了

  1. h
  2. 1
  3. h_1
  4. h1​和
  5. c
  6. 1
  7. c_1
  8. c1​,我们将其传入下一层。

同理我们可以进行接下来 **第

  1. 2
  2. 2
  3. 2个交易日** 的计算。

我们将

  1. h
  2. 1
  3. h_1
  4. h1​与
  5. x
  6. 2
  7. \vec{x_2}
  8. x2​​拼在一起作为
  9. y
  10. 2
  11. \vec{y_2}
  12. y2​​,即
  13. y
  14. 2
  15. =
  16. [
  17. h
  18. 1
  19. ;
  20. x
  21. 2
  22. ]
  23. =
  24. [
  25. 0.0116
  26. 0.001
  27. 0.0556
  28. 0.5618
  29. 3029.4028
  30. 3044.8223
  31. 3045.6399
  32. 3019.1238
  33. ]
  34. T
  35. \vec{y_2} = [h_1; \vec{x_2}] = \begin{bmatrix} 0.0116 & -0.001 & 0.0556 & -0.5618 & 3029.4028 & 3044.8223 & 3045.6399 & 3019.1238 \end{bmatrix}^T
  36. y2​​=[h1​;x2​​]=[0.0116​−0.0010.0556​−0.56183029.40283044.82233045.63993019.1238​]T

我们依次计算遗忘门

  1. f
  2. 2
  3. f_2
  4. f2​,输入门
  5. i
  6. 2
  7. i_2
  8. i2​,输出门
  9. o
  10. 2
  11. o_2
  12. o2​,即
  13. f
  14. 2
  15. =
  16. σ
  17. (
  18. W
  19. f
  20. y
  21. 2
  22. +
  23. b
  24. f
  25. )
  26. =
  27. [
  28. 0.0008
  29. 0.9985
  30. 0.9715
  31. 0.1151
  32. ]
  33. ,
  34. i
  35. 2
  36. =
  37. σ
  38. (
  39. W
  40. i
  41. y
  42. 2
  43. +
  44. b
  45. i
  46. )
  47. =
  48. [
  49. 0.8503
  50. 0.0010
  51. 0.0583
  52. 0.6527
  53. ]
  54. ,
  55. o
  56. 2
  57. =
  58. σ
  59. (
  60. W
  61. o
  62. y
  63. 2
  64. +
  65. b
  66. o
  67. )
  68. =
  69. [
  70. 0.0196
  71. 0.9581
  72. 0.9981
  73. .
  74. 9839
  75. ]
  76. f_2 = \sigma(W_f\vec{y_2} + b_f) = \begin{bmatrix} 0.0008 \\ 0.9985 \\ 0.9715 \\ 0.1151 \end{bmatrix}, i_2 = \sigma(W_i\vec{y_2} + b_i) = \begin{bmatrix} 0.8503 \\ 0.0010 \\ 0.0583 \\ 0.6527 \end{bmatrix}, o_2 = \sigma(W_o\vec{y_2} + b_o) = \begin{bmatrix} 0.0196 \\ 0.9581 \\ 0.9981 \\.9839 \end{bmatrix}
  77. f2​=σ(Wfy2​​+bf​)=​0.00080.99850.97150.1151​​,i2​=σ(Wiy2​​+bi​)=​0.85030.00100.05830.6527​​,o2​=σ(Woy2​​+bo​)=​0.01960.95810.9981.9839​​

随后我们进行计算当前输入单元状态

  1. c
  2. 2
  3. ~
  4. \tilde{c_2}
  5. c2​~​,即
  6. c
  7. 2
  8. ~
  9. =
  10. tanh
  11. (
  12. W
  13. c
  14. y
  15. 2
  16. +
  17. b
  18. c
  19. )
  20. =
  21. [
  22. 0.7935
  23. 0.9998
  24. 0.9806
  25. 0.9994
  26. ]
  27. T
  28. \tilde{c_2} = \text{tanh}(W_c\vec{y_2} + b_c) = \begin{bmatrix} 0.7935 & -0.9998 & 0.9806 & -0.9994 \end{bmatrix}^T
  29. c2​~​=tanh(Wcy2​​+bc​)=[0.7935​−0.99980.9806​−0.9994​]T

接着我们计算当前时刻单元状态

  1. c
  2. 2
  3. c_2
  4. c2​,即
  5. c
  6. 2
  7. =
  8. f
  9. 2
  10. c
  11. 1
  12. +
  13. i
  14. 2
  15. c
  16. 2
  17. ~
  18. =
  19. [
  20. 0.6747
  21. 0.0010
  22. 0.0571
  23. 0.6524
  24. ]
  25. T
  26. c_2 = f_2 \circ c_{1} + i_2 \circ \tilde{c_2} = \begin{bmatrix} 0.6747 & -0.0010 & 0.0571 & -0.6524 \end{bmatrix}^T
  27. c2​=f2​∘c1​+i2​∘c2​~​=[0.6747​−0.00100.0571​−0.6524​]T

最后我们计算当前层隐藏层输出

  1. h
  2. 2
  3. h_2
  4. h2​,即
  5. h
  6. 2
  7. =
  8. o
  9. 2
  10. d
  11. 2
  12. =
  13. o
  14. 2
  15. tanh
  16. (
  17. c
  18. 2
  19. )
  20. =
  21. [
  22. 0.0115
  23. 0.0010
  24. 0.0570
  25. 0.5640
  26. ]
  27. T
  28. h_2 = o_2 \circ d_2 = o_2 \circ \text{tanh}(c_2) = \begin{bmatrix} 0.0115 & -0.0010 & 0.0570 & -0.5640 \end{bmatrix}^T
  29. h2​=o2​∘d2​=o2​∘tanh(c2​)=[0.0115​−0.00100.0570​−0.5640​]T

同理我们可以进行接下来 **第

  1. 3
  2. 3
  3. 3个交易日** 的计算。

我们将

  1. h
  2. 2
  3. h_2
  4. h2​与
  5. x
  6. 3
  7. \vec{x_3}
  8. x3​​拼在一起作为
  9. y
  10. 3
  11. \vec{y_3}
  12. y3​​,即
  13. y
  14. 3
  15. =
  16. [
  17. h
  18. 2
  19. ;
  20. x
  21. 3
  22. ]
  23. =
  24. [
  25. 0.0115
  26. 0.0010
  27. 0.0570
  28. 0.5640
  29. 3037.9272
  30. 3052.8999
  31. 3060.2634
  32. 3034.6499
  33. ]
  34. T
  35. \vec{y_3} = [h_2; \vec{x_3}] = \begin{bmatrix} 0.0115 & -0.0010 & 0.0570 & -0.5640 & 3037.9272 & 3052.8999 & 3060.2634 & 3034.6499 \end{bmatrix}^T
  36. y3​​=[h2​;x3​​]=[0.0115​−0.00100.0570​−0.56403037.92723052.89993060.26343034.6499​]T

我们依次计算遗忘门

  1. f
  2. 3
  3. f_3
  4. f3​,输入门
  5. i
  6. 3
  7. i_3
  8. i3​,输出门
  9. o
  10. 3
  11. o_3
  12. o3​。
  13. f
  14. 3
  15. =
  16. σ
  17. (
  18. W
  19. f
  20. y
  21. 3
  22. +
  23. b
  24. f
  25. )
  26. =
  27. [
  28. 0.0008
  29. 0.9985
  30. 0.9719
  31. 0.1135
  32. ]
  33. ,
  34. i
  35. 3
  36. =
  37. σ
  38. (
  39. W
  40. i
  41. y
  42. 3
  43. +
  44. b
  45. i
  46. )
  47. =
  48. [
  49. 0.8501
  50. 0.0010
  51. 0.0572
  52. 0.6518
  53. ]
  54. ,
  55. o
  56. 3
  57. =
  58. σ
  59. (
  60. W
  61. o
  62. y
  63. 3
  64. +
  65. b
  66. o
  67. )
  68. =
  69. [
  70. 0.0192
  71. 0.9584
  72. 0.9982
  73. 0.9841
  74. ]
  75. f_3 = \sigma(W_f\vec{y_3} + b_f) = \begin{bmatrix} 0.0008 \\ 0.9985 \\ 0.9719 \\ 0.1135 \end{bmatrix}, i_3 = \sigma(W_i\vec{y_3} + b_i) = \begin{bmatrix} 0.8501 \\ 0.0010 \\ 0.0572 \\ 0.6518 \end{bmatrix}, o_3 = \sigma(W_o\vec{y_3} + b_o) = \begin{bmatrix} 0.0192 \\ 0.9584 \\ 0.9982 \\ 0.9841 \end{bmatrix}
  76. f3​=σ(Wfy3​​+bf​)=​0.00080.99850.97190.1135​​,i3​=σ(Wiy3​​+bi​)=​0.85010.00100.05720.6518​​,o3​=σ(Woy3​​+bo​)=​0.01920.95840.99820.9841​​

随后我们进行计算当前输入单元状态

  1. c
  2. 3
  3. ~
  4. \tilde{c_3}
  5. c3​~​,即
  6. c
  7. 3
  8. ~
  9. =
  10. tanh
  11. (
  12. W
  13. c
  14. y
  15. 3
  16. +
  17. b
  18. c
  19. )
  20. =
  21. [
  22. 0.7956
  23. 0.9998
  24. 0.9807
  25. 0.9994
  26. ]
  27. T
  28. \tilde{c_3} = \text{tanh}(W_c\vec{y_3} + b_c) = \begin{bmatrix} 0.7956 & -0.9998 & 0.9807 & -0.9994 \end{bmatrix}^T
  29. c3​~​=tanh(Wcy3​​+bc​)=[0.7956​−0.99980.9807​−0.9994​]T

接着我们计算当前时刻单元状态

  1. c
  2. 3
  3. c_3
  4. c3​,即
  5. c
  6. 3
  7. =
  8. f
  9. 3
  10. c
  11. 2
  12. +
  13. i
  14. 3
  15. c
  16. 3
  17. ~
  18. =
  19. [
  20. 0.6763
  21. 0.0010
  22. 0.0561
  23. 0.6515
  24. ]
  25. T
  26. c_3 = f_3 \circ c_{2} + i_3 \circ \tilde{c_3} = \begin{bmatrix} 0.6763 & -0.0010 & 0.0561 & -0.6515 \end{bmatrix}^T
  27. c3​=f3​∘c2​+i3​∘c3​~​=[0.6763​−0.00100.0561​−0.6515​]T

最后我们计算当前层隐藏层输出

  1. h
  2. 3
  3. h_3
  4. h3​,即
  5. h
  6. 3
  7. =
  8. o
  9. 3
  10. d
  11. 3
  12. =
  13. o
  14. 3
  15. tanh
  16. (
  17. c
  18. 3
  19. )
  20. =
  21. [
  22. 0.0113
  23. 0.0010
  24. 0.0559
  25. 0.5636
  26. ]
  27. T
  28. h_3 = o_3 \circ d_3 = o_3 \circ \text{tanh}(c_3) = \begin{bmatrix} 0.0113 & -0.0010 & 0.0559 & -0.5636 \end{bmatrix}^T
  29. h3​=o3​∘d3​=o3​∘tanh(c3​)=[0.0113​−0.00100.0559​−0.5636​]T

得到了

  1. h
  2. 3
  3. h_3
  4. h3​之后,我们可以简单将
  5. h
  6. 3
  7. h_3
  8. h3​的结果作为预测的结果,然后使用MSE进行计算损失,MSE的计算公式如下所示。
  9. MSE
  10. =
  11. 1
  12. n
  13. i
  14. =
  15. 1
  16. n
  17. (
  18. y
  19. i
  20. ^
  21. y
  22. i
  23. )
  24. 2
  25. \text{MSE} = \frac{1}{n} \sum_{i = 1}^{n} (\hat{y_i} - y_i )^2
  26. MSE=n1i=1n​(yi​^​−yi​)2
  27. MSE
  28. =
  29. 1
  30. 4
  31. [
  32. (
  33. 3054.9793
  34. 0.0113
  35. )
  36. 2
  37. +
  38. (
  39. 3088.6357
  40. +
  41. 0.0010
  42. )
  43. 2
  44. +
  45. (
  46. 3092.43
  47. 0.0559
  48. )
  49. 2
  50. +
  51. (
  52. 3054.9793
  53. +
  54. 0.5636
  55. )
  56. 2
  57. ]
  58. =
  59. 9437756.3022
  60. \text{MSE} = \frac{1}{4} [(3054.9793 - 0.0113)^2 + (3088.6357 + 0.0010)^2 + ( 3092.43 - 0.0559)^2 + (3054.9793 + 0.5636)^2 ] \\ = 9437756.3022
  61. MSE=41​[(3054.97930.0113)2+(3088.6357+0.0010)2+(3092.430.0559)2+(3054.9793+0.5636)2]=9437756.3022

然后我们就得到我们的损失为

  1. 9437756.3022
  2. 9437756.3022
  3. 9437756.3022

以上就完成了一次将LSTM用于预测的计算。可以看到误差很大,实际应用中会先将数据输入到LSTM前,会进行一次归一化,在LSTM的输出后,会将隐藏层的结果进行一层线性映射,然后使用逆归一化,这样得到结果会比较接近我们的指数。

小结

LSTM模型的具体训练步骤如下:

1.LSTM 单元的输入包含当前时刻的输入

  1. v
  2. e
  3. c
  4. x
  5. t
  6. vec{x_t}
  7. vecxt​、上一时刻的输出状态
  8. h
  9. t
  10. 1
  11. h_{t-1}
  12. ht1​以及上一时刻的单元状态
  13. c
  14. t
  15. 1
  16. c_{t-1}
  17. ct1​。在进行运算第一层LSTM单元时,我们会手动初始化
  18. h
  19. 0
  20. h_0
  21. h0​、
  22. c
  23. 0
  24. c_0
  25. c0​,而在后面的LSTM的单元中
  26. h
  27. t
  28. 1
  29. h_{t-1}
  30. ht1​和
  31. c
  32. t
  33. 1
  34. c_{t-1}
  35. ct1​,都可以由上一次的LSTM单元获得。其中,
  36. x
  37. t
  38. R
  39. m
  40. ×
  41. 1
  42. \vec{x_t} \in \mathbb{R}^{m \times 1}
  43. xt​​∈Rm×1
  44. m
  45. m
  46. m是输入特征的维度,
  47. h
  48. t
  49. 1
  50. h_{t-1}
  51. ht1​上一时刻的输出状态,形状为
  52. h
  53. t
  54. 1
  55. R
  56. d
  57. ×
  58. 1
  59. h_{t-1} \in \mathbb{R}^{d \times 1}
  60. ht1​∈Rd×1
  61. d
  62. d
  63. dLSTM单元的隐藏状态大小,
  64. c
  65. t
  66. 1
  67. c_{t-1}
  68. ct1​是上一时刻的单元状态,形状为
  69. c
  70. t
  71. 1
  72. R
  73. d
  74. ×
  75. 1
  76. c_{t-1} \in \mathbb{R}^{d \times 1}
  77. ct1​∈Rd×1

我们通常会把

  1. h
  2. t
  3. 1
  4. h_{t-1}
  5. ht1​和
  6. x
  7. t
  8. \vec{x_t}
  9. xt​​拼在一起形成更长的向量
  10. y
  11. t
  12. \vec{y_t}
  13. yt​​,我们通常竖着拼,即
  14. y
  15. t
  16. R
  17. (
  18. d
  19. +
  20. m
  21. )
  22. ×
  23. 1
  24. \vec{y_t} \in \mathbb{R}^{(d + m) \times 1}
  25. yt​​∈R(d+m1 ,然后
  26. y
  27. t
  28. \vec{y_t}
  29. yt​​会传入各个门。
  30. y
  31. t
  32. =
  33. [
  34. h
  35. t
  36. 1
  37. ;
  38. x
  39. t
  40. ]
  41. =
  42. [
  43. h
  44. t
  45. 1
  46. x
  47. t
  48. ]
  49. \vec{y_t} = [h_{t-1};\vec{x_t}] = \left[{\begin{matrix}h_{t-1} \\ \vec{x_t} \end{matrix}}\right]
  50. yt​​=[ht1​;xt​​]=[ht1xt​​​]

2.随后是计算各个门的输出,各个门的输入是

  1. y
  2. t
  3. \vec{y_t}
  4. yt​​。我们将
  5. y
  6. t
  7. \vec{y_t}
  8. yt​​与门中的权重矩阵
  9. W
  10. W
  11. W相乘再加上置偏值
  12. b
  13. b
  14. b,得到中间结果
  15. M
  16. M
  17. M。然后对
  18. M
  19. M
  20. MSigmoid,得到门的输出
  21. g
  22. t
  23. g_t
  24. gt​,其形状与单元状态
  25. c
  26. t
  27. c_t
  28. ct​相同,即
  29. g
  30. t
  31. R
  32. d
  33. ×
  34. 1
  35. g_t \in \mathbb{R}^{d \times 1}
  36. gt​∈Rd×1
  37. f
  38. t
  39. =
  40. σ
  41. (
  42. W
  43. f
  44. y
  45. t
  46. +
  47. b
  48. f
  49. )
  50. =
  51. 1
  52. 1
  53. +
  54. e
  55. (
  56. W
  57. f
  58. y
  59. t
  60. +
  61. b
  62. f
  63. )
  64. f_t = \sigma(W_f\vec{y_t}' + b_f) = \frac{1}{1 + e^{-(W_f\vec{y_t} + b_f)}}
  65. ft​=σ(Wf​yt​​′+bf​)=1+e−(Wf​yt​​+bf​)1​
  66. i
  67. t
  68. =
  69. σ
  70. (
  71. W
  72. i
  73. y
  74. t
  75. +
  76. b
  77. i
  78. )
  79. =
  80. 1
  81. 1
  82. +
  83. e
  84. (
  85. W
  86. i
  87. y
  88. t
  89. +
  90. b
  91. i
  92. )
  93. i_t = \sigma(W_i\vec{y_t} + b_i) = \frac{1}{1 + e^{-(W_i\vec{y_t} + b_i)}}
  94. it​=σ(Wi​yt​​+bi​)=1+e−(Wi​yt​​+bi​)1​
  95. o
  96. t
  97. =
  98. σ
  99. (
  100. W
  101. o
  102. y
  103. t
  104. +
  105. b
  106. o
  107. )
  108. =
  109. 1
  110. 1
  111. +
  112. e
  113. (
  114. W
  115. f
  116. y
  117. t
  118. +
  119. b
  120. o
  121. )
  122. o_t = \sigma(W_o\vec{y_t} + b_o) = \frac{1}{1 + e^{-(W_f\vec{y_t} + b_o)}}
  123. ot​=σ(Wo​yt​​+bo​)=1+e−(Wf​yt​​+bo​)1​

其中,

  1. y
  2. t
  3. R
  4. (
  5. d
  6. +
  7. m
  8. )
  9. ×
  10. 1
  11. \vec{y_t} \in \mathbb{R}^{(d + m) \times 1}
  12. yt​​∈R(d+m1
  13. W
  14. f
  15. W
  16. i
  17. W
  18. o
  19. R
  20. d
  21. ×
  22. (
  23. d
  24. +
  25. m
  26. )
  27. W_fW_iW_o \in \mathbb{R}^{d \times (d + m)}
  28. Wf​、Wi​、Wo​∈Rd×(d+m),
  29. b
  30. f
  31. b
  32. i
  33. b
  34. o
  35. R
  36. d
  37. ×
  38. 1
  39. b_fb_ib_o \in \mathbb{R}^{d \times 1}
  40. bf​、bi​、bo​∈Rd×1
  41. f
  42. t
  43. i
  44. t
  45. o
  46. t
  47. R
  48. d
  49. ×
  50. 1
  51. f_ti_to_t \in \mathbb{R}^{d \times 1}
  52. ft​、it​、ot​∈Rd×1

3.计算当前输入单元状态

  1. c
  2. t
  3. ~
  4. \tilde{c_t}
  5. ct​~​的值,表示当前输入要保留多少内容到记忆中。我们将
  6. y
  7. t
  8. \vec{y_t}
  9. yt​​与当前时刻状态单元的权重矩阵
  10. W
  11. c
  12. W_c
  13. Wc​相乘再加上置偏值
  14. b
  15. c
  16. b_c
  17. bc​,得到中间结果
  18. M
  19. c
  20. M_c
  21. Mc​,然后对
  22. M
  23. c
  24. M_c
  25. Mc​取tanh,得到输出
  26. c
  27. t
  28. ~
  29. \tilde{c_t}
  30. ct​~​。
  31. c
  32. t
  33. ~
  34. =
  35. tanh
  36. (
  37. W
  38. c
  39. y
  40. t
  41. +
  42. b
  43. c
  44. )
  45. =
  46. e
  47. (
  48. W
  49. c
  50. y
  51. t
  52. +
  53. b
  54. c
  55. )
  56. e
  57. (
  58. W
  59. c
  60. y
  61. t
  62. +
  63. b
  64. c
  65. )
  66. e
  67. (
  68. W
  69. c
  70. y
  71. t
  72. +
  73. b
  74. c
  75. )
  76. +
  77. e
  78. (
  79. W
  80. c
  81. y
  82. t
  83. +
  84. b
  85. c
  86. )
  87. \tilde{c_t} = \text{tanh}(W_c\vec{y_t} + b_c) = \frac{e^{(W_c\vec{y_t} + b_c)}-e^{-(W_c\vec{y_t} + b_c)}}{e^{(W_c\vec{y_t} + b_c)}+e^{-(W_c\vec{y_t} + b_c)}}
  88. ct​~​=tanh(Wcyt​​+bc​)=e(Wcyt​​+bc​)+e−(Wcyt​​+bc​)e(Wcyt​​+bc​)−e−(Wcyt​​+bc​)​

其中,

  1. y
  2. t
  3. R
  4. (
  5. d
  6. +
  7. m
  8. )
  9. ×
  10. 1
  11. \vec{y_t} \in \mathbb{R}^{(d + m) \times 1}
  12. yt​​∈R(d+m1
  13. W
  14. c
  15. R
  16. d
  17. ×
  18. (
  19. d
  20. +
  21. m
  22. )
  23. W_c \in \mathbb{R}^{d \times (d + m)}
  24. Wc​∈Rd×(d+m),
  25. b
  26. c
  27. R
  28. d
  29. ×
  30. 1
  31. b_c \in \mathbb{R}^{d \times 1}
  32. bc​∈Rd×1
  33. c
  34. t
  35. ~
  36. R
  37. d
  38. ×
  39. 1
  40. \tilde{c_t} \in \mathbb{R}^{d \times 1}
  41. ct​~​∈Rd×1

4.接下来我们进行当前时刻单元状态

  1. c
  2. t
  3. c_t
  4. ct​的计算。我们使用遗忘门和输入门得到的结果
  5. f
  6. t
  7. f_t
  8. ft​、
  9. i
  10. t
  11. i_t
  12. it​和上一时刻单元状态
  13. c
  14. t
  15. 1
  16. c_{t-1}
  17. ct1​来计算当前时刻单元状态
  18. c
  19. t
  20. c_t
  21. ct​。我们分别将
  22. f
  23. t
  24. f_t
  25. ft​、
  26. c
  27. t
  28. 1
  29. c_{t-1}
  30. ct1​按元素相乘,
  31. i
  32. t
  33. i_t
  34. it​和
  35. c
  36. t
  37. ~
  38. \tilde{c_t}
  39. ct​~​按元素相乘,然后再将两者相加得到我们的但钱时刻单元状态
  40. c
  41. t
  42. c_t
  43. ct​。
  44. c
  45. t
  46. =
  47. f
  48. t
  49. c
  50. t
  51. 1
  52. +
  53. i
  54. t
  55. c
  56. t
  57. ~
  58. c_t = f_t \circ c_{t-1} + i_t \circ \tilde{c_t}
  59. ct​=ft​∘ct1​+it​∘ct​~​

其中,

  1. f
  2. t
  3. R
  4. d
  5. ×
  6. 1
  7. f_t \in \mathbb{R}^{d \times 1}
  8. ft​∈Rd×1时遗忘门输出,
  9. i
  10. t
  11. R
  12. d
  13. ×
  14. 1
  15. i_t \in \mathbb{R}^{d \times 1}
  16. it​∈Rd×1是输入门输出,
  17. c
  18. t
  19. ~
  20. R
  21. d
  22. ×
  23. 1
  24. \tilde{c_{t}} \in \mathbb{R}^{d \times 1}
  25. ct​~​∈Rd×1是当前输入状态单元,
  26. c
  27. t
  28. 1
  29. R
  30. d
  31. ×
  32. 1
  33. c_{t-1} \in \mathbb{R}^{d \times 1}
  34. ct1​∈Rd×1 是上一时刻状态单元,
  35. \circ
  36. ∘表示 **按元素乘**。

5.最后模型的输出是

  1. h
  2. t
  3. h_t
  4. ht​和当前时刻的单元状态
  5. c
  6. t
  7. c_t
  8. ct​,而
  9. h
  10. t
  11. h_t
  12. ht​由当前时刻的单元状态
  13. c
  14. t
  15. c_t
  16. ct​​和输出门的输出
  17. o
  18. t
  19. o_t
  20. ot​确定。我们将当前时刻的单元状态
  21. c
  22. t
  23. c_t
  24. ct​取 tanh得到
  25. d
  26. t
  27. d_t
  28. dt​,然后将
  29. d
  30. t
  31. d_t
  32. dt
  33. o
  34. t
  35. o_t
  36. ot​按元素相乘得到最后的
  37. h
  38. t
  39. h_t
  40. ht​。
  41. h
  42. t
  43. =
  44. o
  45. t
  46. d
  47. t
  48. =
  49. o
  50. t
  51. tanh
  52. (
  53. c
  54. t
  55. )
  56. =
  57. e
  58. c
  59. t
  60. e
  61. c
  62. t
  63. e
  64. c
  65. t
  66. +
  67. e
  68. c
  69. t
  70. h_t = o_t \circ d_t = o_t \circ \text{tanh}(c_t) = \frac{e^{c_t}-e^{-c_t}}{e^{c_t}+e^{-c_t}}
  71. ht​=ot​∘dt​=ot​∘tanh(ct​)=ect​+ectect​−ect​​

其中

  1. h
  2. t
  3. R
  4. d
  5. ×
  6. 1
  7. h_t \in \mathbb{R}^{d \times 1}
  8. ht​∈Rd×1 为当前层隐藏状态,
  9. o
  10. t
  11. R
  12. d
  13. ×
  14. 1
  15. o_t \in \mathbb{R}^{d \times 1}
  16. ot​∈Rd×1为输出门的输出,
  17. c
  18. t
  19. R
  20. d
  21. ×
  22. 1
  23. c_t \in \mathbb{R}^{d \times 1}
  24. ct​∈Rd×1为当前时刻状态单元。
  1. import torch
  2. import torch.nn as nn
  3. import numpy as np
  4. import pandas as pd
  5. import matplotlib.pyplot as plt
  6. from sklearn.preprocessing import MinMaxScaler
  7. # 读取数据
  8. df = pd.read_csv('sh_data.csv')
  9. df = df.iloc[-30:,[2,5,3,4]]
  10. df1 = df[25:28].reset_index(drop=True)
  11. df2 = df1.reset_index(drop=True)
  12. data = df[['open','close','high','low']].values.astype(float)# 标准化数据
  13. scaler = MinMaxScaler(feature_range=(0,1))
  14. data = scaler.fit_transform(data)# 创建时间序列数据defcreate_sequences(data, time_step=1):
  15. X, y =[],[]for i inrange(len(data)- time_step):
  16. X.append(data[i:(i + time_step)])
  17. y.append(data[i + time_step])return np.array(X), np.array(y)
  18. time_step =2# 时间步长设置为2天
  19. X, y = create_sequences(data, time_step)# 转换为PyTorch张量
  20. X = torch.FloatTensor(X)
  21. y = torch.FloatTensor(y)classLSTM(nn.Module):def__init__(self, input_size, hidden_layer_size, output_size):super(LSTM, self).__init__()
  22. self.hidden_layer_size = hidden_layer_size
  23. self.lstm = nn.LSTM(input_size, hidden_layer_size)
  24. self.linear = nn.Linear(hidden_layer_size, output_size)
  25. self.hidden_cell =(torch.zeros(1,1, self.hidden_layer_size),
  26. torch.zeros(1,1, self.hidden_layer_size))defforward(self, input_seq):
  27. lstm_out, self.hidden_cell = self.lstm(input_seq.view(len(input_seq),1,-1), self.hidden_cell)
  28. predictions = self.linear(lstm_out.view(len(input_seq),-1))return predictions[-1]
  29. input_size =4# 输入特征数量
  30. hidden_layer_size =4
  31. output_size =4# 输出特征数量
  32. model = LSTM(input_size=input_size, hidden_layer_size=hidden_layer_size, output_size=output_size)
  33. loss_function = nn.MSELoss()
  34. optimizer = torch.optim.Adam(model.parameters(), lr=0.001)# epochs = 1# for i in range(epochs):# for seq, labels in zip(X, y):# optimizer.zero_grad()# model.hidden_cell = (torch.zeros(1, 1, model.hidden_layer_size),# torch.zeros(1, 1, model.hidden_layer_size))# y_pred = model(seq)# single_loss = loss_function(y_pred, labels)# single_loss.backward()# optimizer.step()# if i % 10 == 0:# print(f'epoch: {i:3} loss: {single_loss.item():10.8f}')# 只进行一次训练
  35. seq, labels = X[0], y[0]
  36. optimizer.zero_grad()
  37. model.hidden_cell =(torch.zeros(1,1, model.hidden_layer_size),
  38. torch.zeros(1,1, model.hidden_layer_size))
  39. y_pred = model(seq)
  40. single_loss = loss_function(y_pred, labels)
  41. single_loss.backward()
  42. optimizer.step()print(f'Single training loss: {single_loss.item():10.8f}')
  43. model.eval()# 预测下一天的四个特征with torch.no_grad():
  44. seq = torch.FloatTensor(data[-time_step:])
  45. model.hidden_cell =(torch.zeros(1,1, model.hidden_layer_size),
  46. torch.zeros(1,1, model.hidden_layer_size))
  47. next_day = model(seq).numpy()# 将预测结果逆归一化
  48. next_day = scaler.inverse_transform(next_day.reshape(-1, output_size))print(f'Predicted features for the next day: open={next_day[0][0]}, close={next_day[0][1]}, high={next_day[0][2]}, low={next_day[0][3]}')# 获取训练集的预测值
  49. train_predict =[]for seq in X:with torch.no_grad():
  50. model.hidden_cell =(torch.zeros(1,1, model.hidden_layer_size),
  51. torch.zeros(1,1, model.hidden_layer_size))
  52. train_predict.append(model(seq).numpy())# 将预测结果逆归一化
  53. train_predict = scaler.inverse_transform(np.array(train_predict).reshape(-1, output_size))
  54. actual = scaler.inverse_transform(data)# 绘制图形
  55. plt.figure(figsize=(10,6))for i, col inenumerate(['open','close','high','low']):
  56. plt.subplot(2,2, i+1)
  57. plt.plot(actual[:, i], label=f'Actual {col}')
  58. plt.plot(range(time_step, time_step +len(train_predict)), train_predict[:, i], label=f'Train Predict {col}')
  59. plt.legend()
  60. plt.tight_layout()
  61. plt.show()
标签: lstm 人工智能 rnn

本文转载自: https://blog.csdn.net/weixin_44555174/article/details/140758757
版权归原作者 --fancy 所有, 如有侵权,请联系我们删除。

“LSTM模型计算详解”的评论:

还没有评论