0


小白也能读懂的ConvLSTM!(开源pytorch代码)

ConvLSTM

仅需要网络源码的可以直接跳到末尾即可

1. 算法简介与应用场景

ConvLSTM(卷积长短期记忆网络)是一种结合了卷积神经网络(CNN)和长短期记忆网络(LSTM)优势的深度学习模型。它主要用于处理时空数据,特别适用于需要考虑空间特征和时间依赖关系的任务,如气象预测、视频分析、交通流量预测等。

在气象预测中,ConvLSTM可以根据过去的气象数据(如降水、温度等)预测未来的天气情况。在视频分析中,它可以帮助识别视频中的活动或事件,利用时间序列的连续性和空间信息进行更准确的分析。

2. 算法原理

2.1 LSTM基础

在介绍ConvLSTM之前,先让我们来回归一下什么是长短期记忆网络(LSTM)。LSTM是一种特殊的循环神经网络(RNN),它通过引入门控机制解决了传统RNN在长序列训练中面临的梯度消失和爆炸问题。LSTM单元主要包含三个门:输入门、遗忘门和输出门。这些门控制着信息在单元中的流动,从而有效地记住或遗忘信息。

LSTM的核心公式如下:

  • 遗忘门: f t = σ ( W f ⋅ [ h t − 1 , x t ] + b f ) f_t = \sigma(W_f \cdot [h_{t-1}, x_t] + b_f) ft​=σ(Wf​⋅[ht−1​,xt​]+bf​)
  • 输入门: i t = σ ( W i ⋅ [ h t − 1 , x t ] + b i ) i_t = \sigma(W_i \cdot [h_{t-1}, x_t] + b_i) it​=σ(Wi​⋅[ht−1​,xt​]+bi​) C ~ t = tanh ⁡ ( W C ⋅ [ h t − 1 , x t ] + b C ) \tilde{C}t = \tanh(W_C \cdot [h{t-1}, x_t] + b_C) C~t​=tanh(WC​⋅[ht−1​,xt​]+bC​)
  • 单元状态更新: C t = f t ∗ C t − 1 + i t ∗ C ~ t C_t = f_t \ast C_{t-1} + i_t \ast \tilde{C}_t Ct​=ft​∗Ct−1​+it​∗C~t​
  • 输出门: o t = σ ( W o ⋅ [ h t − 1 , x t ] + b o ) o_t = \sigma(W_o \cdot [h_{t-1}, x_t] + b_o) ot​=σ(Wo​⋅[ht−1​,xt​]+bo​) h t = o t ∗ tanh ⁡ ( C t ) h_t = o_t \ast \tanh(C_t) ht​=ot​∗tanh(Ct​)

这里,

  1. C
  2. t
  3. C_t
  4. Ct 是当前的单元状态,
  5. h
  6. t
  7. h_t
  8. ht 是当前的隐藏状态,
  9. x
  10. t
  11. x_t
  12. xt 是当前的输入。

2.2 ConvLSTM原理

ConvLSTM在LSTM的基础上引入了卷积操作。传统的LSTM使用全连接层处理输入数据,而ConvLSTM则采用卷积层来处理空间数据。这样,ConvLSTM能够更好地捕捉输入数据中的空间特征。
在这里插入图片描述

2.2.1 ConvLSTM的结构

ConvLSTM的单元结构与LSTM非常相似,但是在每个门的计算中使用了卷积操作。具体来说,ConvLSTM的每个门的公式可以表示为:

  1. i
  2. t
  3. =
  4. σ
  5. (
  6. W
  7. x
  8. i
  9. X
  10. t
  11. +
  12. W
  13. h
  14. i
  15. H
  16. t
  17. 1
  18. +
  19. W
  20. c
  21. i
  22. C
  23. t
  24. 1
  25. +
  26. b
  27. i
  28. )
  29. i_t = \sigma (W_{xi} * X_t + W_{hi} * H_{t-1} + W_{ci} \circ C_{t-1} + b_i)
  30. it​=σ(Wxi​∗Xt​+Whi​∗Ht1​+Wci​∘Ct1​+bi​)
  31. f
  32. t
  33. =
  34. σ
  35. (
  36. W
  37. x
  38. f
  39. X
  40. t
  41. +
  42. W
  43. h
  44. f
  45. H
  46. t
  47. 1
  48. +
  49. W
  50. c
  51. f
  52. C
  53. t
  54. 1
  55. +
  56. b
  57. f
  58. )
  59. f_t = \sigma (W_{xf} * X_t + W_{hf} * H_{t-1} + W_{cf} \circ C_{t-1} + b_f)
  60. ft​=σ(Wxf​∗Xt​+Whf​∗Ht1​+Wcf​∘Ct1​+bf​)
  61. C
  62. t
  63. =
  64. f
  65. t
  66. C
  67. t
  68. 1
  69. +
  70. i
  71. t
  72. t
  73. a
  74. n
  75. h
  76. (
  77. W
  78. x
  79. c
  80. X
  81. t
  82. +
  83. W
  84. h
  85. c
  86. H
  87. t
  88. 1
  89. +
  90. b
  91. c
  92. )
  93. C_t = f_t \circ C_{t-1} + i_t \circ tanh(W_{xc} * X_t + W_{hc} * H_{t-1} + b_c)
  94. Ct​=ft​∘Ct1​+it​∘tanh(Wxc​∗Xt​+Whc​∗Ht1​+bc​)
  95. o
  96. t
  97. =
  98. σ
  99. (
  100. W
  101. x
  102. o
  103. X
  104. t
  105. +
  106. W
  107. h
  108. o
  109. H
  110. t
  111. 1
  112. +
  113. W
  114. c
  115. o
  116. C
  117. t
  118. +
  119. b
  120. o
  121. )
  122. o_t = \sigma (W_{xo} * X_t + W_{ho} * H_{t-1} + W_{co} \circ C_t + b_o)
  123. ot​=σ(Wxo​∗Xt​+Who​∗Ht1​+Wco​∘Ct​+bo​)
  124. H
  125. t
  126. =
  127. o
  128. t
  129. t
  130. a
  131. n
  132. h
  133. (
  134. C
  135. t
  136. )
  137. H_t = o_t \circ tanh(C_t)
  138. Ht​=ot​∘tanh(Ct​)

这里的 所有

  1. W
  2. W
  3. W都是是卷积权重,
  4. b
  5. b
  6. b是偏置项,
  7. σ
  8. \sigma
  9. σ sigmoid 函数,
  10. tanh
  11. \tanh
  12. tanh 是双曲正切函数。。

在这里插入图片描述

2.2.2 卷积操作的优点
  1. 空间特征提取:卷积操作能够有效提取输入数据中的空间特征。对于图像数据,卷积操作可以捕捉局部特征,例如边缘、纹理等,这在时间序列数据中同样适用。
  2. 参数共享:卷积操作通过使用相同的卷积核在不同位置计算特征,从而减少了模型参数的数量,降低了计算复杂度。
  3. 平移不变性:卷积网络对输入数据的平移具有不变性,即相同的特征在不同位置都会被检测到,这对于时空序列数据来说是非常重要的。

2.3 LSTM与ConvLSTM的对比分析

特性LSTMConvLSTM输入类型一维序列三维数据(时序的图像数据)处理方式全连接层卷积操作空间特征捕捉较弱较强应用场景自然语言处理、时间序列预测图像序列预测、视频分析

2.4 ConvLSTM的应用

ConvLSTM在多个领域中表现出色,特别适合处理具有时空特征的数据。以下是一些主要的应用场景:

  • 气象预测:利用历史气象数据(如温度、湿度、降水等)来预测未来的天气情况。
  • 视频分析:对视频中的动态场景进行建模,识别和预测视频中的活动。
  • 交通流量预测:基于历史交通数据预测未来的交通流量,帮助城市交通管理。
  • 医学影像分析:分析医学影像序列(如CT、MRI)中的变化,辅助疾病诊断。

3. PyTorch代码

以下是ConvLSTM的完整代码,可以直接拿来用:

  1. import torch.nn as nn
  2. import torch
  3. classConvLSTMCell(nn.Module):def__init__(self, input_dim, hidden_dim, kernel_size, bias):"""
  4. 初始化卷积 LSTM 单元。
  5. 参数:
  6. ----------
  7. input_dim: int
  8. 输入张量的通道数。
  9. hidden_dim: int
  10. 隐藏状态的通道数。
  11. kernel_size: (int, int)
  12. 卷积核的大小。
  13. bias: bool
  14. 是否添加偏置项。
  15. """super(ConvLSTMCell, self).__init__()
  16. self.input_dim = input_dim
  17. self.hidden_dim = hidden_dim
  18. self.kernel_size = kernel_size
  19. # 计算填充大小以保持输入和输出尺寸一致
  20. self.padding = kernel_size[0]//2, kernel_size[1]//2
  21. self.bias = bias
  22. # 定义卷积层,输入是输入维度加上隐藏维度,输出是4倍的隐藏维度(对应i, f, o, g)
  23. self.conv = nn.Conv2d(in_channels=self.input_dim + self.hidden_dim,
  24. out_channels=4* self.hidden_dim,
  25. kernel_size=self.kernel_size,
  26. padding=self.padding,
  27. bias=self.bias)defforward(self, input_tensor, cur_state):
  28. h_cur, c_cur = cur_state
  29. # 沿着通道轴进行拼接
  30. combined = torch.cat([input_tensor, h_cur], dim=1)
  31. combined_conv = self.conv(combined)# 将输出分割成四个部分,分别对应输入门、遗忘门、输出门和候选单元状态
  32. cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=1)
  33. i = torch.sigmoid(cc_i)
  34. f = torch.sigmoid(cc_f)
  35. o = torch.sigmoid(cc_o)
  36. g = torch.tanh(cc_g)# 更新单元状态
  37. c_next = f * c_cur + i * g
  38. # 更新隐藏状态
  39. h_next = o * torch.tanh(c_next)return h_next, c_next
  40. definit_hidden(self, batch_size, image_size):
  41. height, width = image_size
  42. # 初始化隐藏状态和单元状态为零return(torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device),
  43. torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device))classConvLSTM(nn.Module):"""
  44. 卷积 LSTM 层。
  45. 参数:
  46. ----------
  47. input_dim: 输入通道数
  48. hidden_dim: 隐藏通道数
  49. kernel_size: 卷积核大小
  50. num_layers: LSTM 层的数量
  51. batch_first: 批次是否在第一维
  52. bias: 卷积中是否有偏置项
  53. return_all_layers: 是否返回所有层的计算结果
  54. 输入:
  55. ------
  56. 一个形状为 B, T, C, H, W 或者 T, B, C, H, W 的张量
  57. 输出:
  58. ------
  59. 元组包含两个列表(长度为 num_layers 或者长度为 1 如果 return_all_layers 为 False):
  60. 0 - layer_output_list 是长度为 T 的每个输出的列表
  61. 1 - last_state_list 是最后的状态列表,其中每个元素是一个 (h, c) 对应隐藏状态和记忆状态
  62. 示例:
  63. >>> x = torch.rand((32, 10, 64, 128, 128))
  64. >>> convlstm = ConvLSTM(64, 16, 3, 1, True, True, False)
  65. >>> _, last_states = convlstm(x)
  66. >>> h = last_states[0][0] # 0 表示层索引,0 表示 h 索引
  67. """def__init__(self, input_dim, hidden_dim, kernel_size, num_layers,
  68. batch_first=False, bias=True, return_all_layers=False):super(ConvLSTM, self).__init__()# 检查 kernel_size 的一致性
  69. self._check_kernel_size_consistency(kernel_size)# 确保 kernel_size hidden_dim 的长度与层数一致
  70. kernel_size = self._extend_for_multilayer(kernel_size, num_layers)
  71. hidden_dim = self._extend_for_multilayer(hidden_dim, num_layers)ifnotlen(kernel_size)==len(hidden_dim)== num_layers:raise ValueError('不一致的列表长度。')
  72. self.input_dim = input_dim
  73. self.hidden_dim = hidden_dim
  74. self.kernel_size = kernel_size
  75. self.num_layers = num_layers
  76. self.batch_first = batch_first
  77. self.bias = bias
  78. self.return_all_layers = return_all_layers
  79. # 创建 ConvLSTMCell 列表
  80. cell_list =[]for i inrange(0, self.num_layers):
  81. cur_input_dim = self.input_dim if i ==0else self.hidden_dim[i -1]
  82. cell_list.append(ConvLSTMCell(input_dim=cur_input_dim,
  83. hidden_dim=self.hidden_dim[i],
  84. kernel_size=self.kernel_size[i],
  85. bias=self.bias))
  86. self.cell_list = nn.ModuleList(cell_list)defforward(self, input_tensor, hidden_state=None):"""
  87. 前向传播函数。
  88. 参数:
  89. ----------
  90. input_tensor: 输入张量,形状为 (t, b, c, h, w) 或者 (b, t, c, h, w)
  91. hidden_state: 初始隐藏状态,默认为 None
  92. 返回:
  93. -------
  94. last_state_list, layer_output
  95. """ifnot self.batch_first:# 改变输入张量的顺序,如果 batch_first False
  96. input_tensor = input_tensor.permute(1,0,2,3,4)
  97. b, _, _, h, w = input_tensor.size()# 实现状态化的 ConvLSTMif hidden_state isnotNone:raise NotImplementedError()else:# 初始化隐藏状态
  98. hidden_state = self._init_hidden(batch_size=b,
  99. image_size=(h, w))
  100. layer_output_list =[]
  101. last_state_list =[]
  102. seq_len = input_tensor.size(1)
  103. cur_layer_input = input_tensor
  104. for layer_idx inrange(self.num_layers):
  105. h, c = hidden_state[layer_idx]
  106. output_inner =[]for t inrange(seq_len):# 在每个时间步上更新状态
  107. h, c = self.cell_list[layer_idx](input_tensor=cur_layer_input[:, t,:,:,:],
  108. cur_state=[h, c])
  109. output_inner.append(h)# 将输出堆叠起来
  110. layer_output = torch.stack(output_inner, dim=1)
  111. cur_layer_input = layer_output
  112. layer_output_list.append(layer_output)
  113. last_state_list.append([h, c])ifnot self.return_all_layers:# 如果不需要返回所有层,则只返回最后一层的输出和状态
  114. layer_output_list = layer_output_list[-1:]
  115. last_state_list = last_state_list[-1:]return layer_output_list, last_state_list
  116. def_init_hidden(self, batch_size, image_size):
  117. init_states =[]for i inrange(self.num_layers):# 初始化每一层的隐藏状态
  118. init_states.append(self.cell_list[i].init_hidden(batch_size, image_size))return init_states
  119. @staticmethoddef_check_kernel_size_consistency(kernel_size):ifnot(isinstance(kernel_size,tuple)or(isinstance(kernel_size,list)andall([isinstance(elem,tuple)for elem in kernel_size]))):raise ValueError('`kernel_size` 必须是 tuple 或者 list of tuples')@staticmethoddef_extend_for_multilayer(param, num_layers):ifnotisinstance(param,list):
  120. param =[param]* num_layers
  121. return param

参考文献

  1. [1]Shi, X., Chen, Z., Wang, H., Yeung, D. Y., Wong, W. K., & Woo, W. (2015). Convolutional LSTM Network: A Machine Learning [2]Approach for Precipitation Nowcasting. Advances in Neural Information Processing Systems, 28.
  2. [2]Hochreiter, S., & Schmidhuber, J. (1997). Long Short-Term Memory. Neural Computation, 9(8), 1735-1780.
  3. [3]Goodfellow, I., Bengio, Y., & Courville, A. (2016). Deep Learning. MIT Press.

本文转载自: https://blog.csdn.net/m0_59257547/article/details/140758429
版权归原作者 机器学习与优化算法 所有, 如有侵权,请联系我们删除。

“小白也能读懂的ConvLSTM!(开源pytorch代码)”的评论:

还没有评论