

An Illustrated, Accessible Guide to Vision Mamba (ViM)



Introduction: Mamba

2024-04-29 16:06:08: Starting today, I am documenting my process of learning and using the Mamba module.


Chapter 1: Environment Setup

Personally tested: follow the installation steps below and it will work.

Code used

  1. Vision Mamba

https://github.com/hustvl/Vim

  1. git clone https://github.com/hustvl/Vim.git

1.1 Installation Tutorial

  1. Installation tutorial: after downloading Vision Mamba, follow the tutorial below step by step and the installation will succeed.

Vision Mamba training run log: fixing the bimamba_type error

1.2 Troubleshooting

  1. Troubleshooting: for problems you may run into, refer to the link below; it is a fairly comprehensive summary.

A roundup of Mamba environment installation pitfalls and their solutions

1.3 Installation Summary

The key is to download causal_conv1d and mamba_ssm. It is best to download the offline whl files and then install them with pip. One important point: replace the mamba_ssm installed in the conda environment with the mamba_ssm from the official project. Example commands are given below.
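For reference, the offline installation typically looks like the commands below. This is only a sketch: the wheel file names depend on your Python, CUDA, and PyTorch versions, and everything in angle brackets is a placeholder you need to fill in yourself.

  1. pip install causal_conv1d-<version>-<python-tag>-<platform-tag>.whl
  2. pip install mamba_ssm-<version>-<python-tag>-<platform-tag>.whl
  3. # then overwrite the mamba_ssm installed in the conda environment with the one shipped in the official Vim project
  4. cp -r <path-to-mamba_ssm-in-the-Vim-repo> <conda-env>/lib/<python-version>/site-packages/mamba_ssm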


Chapter 2: Plug-and-Play Modules

2.1 Module 1: Vision Mamba

GitHub: https://github.com/hustvl/Vim
After downloading the code and setting up the environment, replace Vim/vim/models_mamba.py with the code below; it can then be run directly.

Run command

  1. python models_mamba.py
Code: models_mamba.py
  1. # Copyright (c) 2015-present, Facebook, Inc. All rights reserved.
  import torch
  2. import torch.nn as nn
  3. from functools import partial
  4. from torch import Tensor
  5. from typing import Optional
  6. from timm.models.vision_transformer import VisionTransformer, _cfg
  7. from timm.models.registry import register_model
  8. from timm.models.layers import trunc_normal_, lecun_normal_
  9. from timm.models.layers import DropPath, to_2tuple
  10. from timm.models.vision_transformer import _load_weights
  11. import math
  12. from collections import namedtuple
  13. from mamba_ssm.modules.mamba_simple import Mamba
  14. from mamba_ssm.utils.generation import GenerationMixin
  15. from mamba_ssm.utils.hf import load_config_hf, load_state_dict_hf
  16. from rope import *
  17. import random
  18. try:
  19. from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
  20. except ImportError:
  21. RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
  22. __all__ =['vim_tiny_patch16_224', 'vim_small_patch16_224', 'vim_base_patch16_224',
  23. 'vim_tiny_patch16_384', 'vim_small_patch16_384', 'vim_base_patch16_384',
  24. ]
  25. class PatchEmbed(nn.Module):
  26. """ 2D Image to Patch Embedding
  27. """
  28. def __init__(self, img_size=224, patch_size=16, stride=16, in_chans=3, embed_dim=768, norm_layer=None, flatten=True):
  29. super().__init__()
  30. img_size = to_2tuple(img_size)
  31. patch_size = to_2tuple(patch_size)
  32. self.img_size = img_size
  33. self.patch_size = patch_size
  34. self.grid_size =((img_size[0] - patch_size[0])// stride +1,(img_size[1] - patch_size[1])// stride +1)
  35. self.num_patches = self.grid_size[0] * self.grid_size[1]
  36. self.flatten = flatten
  37. self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride)
  38. self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
  39. def forward(self, x):
  40. B, C, H, W = x.shape
  41. assert H == self.img_size[0] and W == self.img_size[1], \
  42. f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
  43. x = self.proj(x)
  44. if self.flatten:
  45. x = x.flatten(2).transpose(1,2) # BCHW -> BNC
  46. x = self.norm(x)
  47. return x
  48. class Block(nn.Module):
  49. def __init__(
  50. self, dim, mixer_cls, norm_cls=nn.LayerNorm, fused_add_norm=False, residual_in_fp32=False,drop_path=0.,):
  51. """
  52. Simple block wrapping a mixer class with LayerNorm/RMSNorm and residual connection"
  53. This Block has a slightly different structure compared to a regular
  54. prenorm Transformer block.
  55. The standard block is: LN -> MHA/MLP -> Add.
  56. [Ref: https://arxiv.org/abs/2002.04745]
  57. Here we have: Add -> LN -> Mixer, returning both
  58. the hidden_states (output of the mixer) and the residual.
  59. This is purely for performance reasons, as we can fuse add and LayerNorm.
  60. The residual needs to be provided (except for the very first block).
  61. """
  62. super().__init__()
  63. self.residual_in_fp32 = residual_in_fp32
  64. self.fused_add_norm = fused_add_norm
  65. self.mixer = mixer_cls(dim)
  66. self.norm = norm_cls(dim)
  67. self.drop_path = DropPath(drop_path) if drop_path >0. else nn.Identity()
  68. if self.fused_add_norm:
  69. assert RMSNorm is not None, "RMSNorm import fails"
  70. assert isinstance(
  71. self.norm,(nn.LayerNorm, RMSNorm)), "Only LayerNorm and RMSNorm are supported for fused_add_norm"
  72. def forward(
  73. self, hidden_states: Tensor, residual: Optional[Tensor] = None, inference_params=None
  74. ):
  75. r"""Pass the input through the encoder layer.
  76. Args:
  77. hidden_states: the sequence to the encoder layer (required).
  78. residual: hidden_states = Mixer(LN(residual))"""
  79. if not self.fused_add_norm:
  80. if residual is None:
  81. residual = hidden_states
  82. else:
  83. residual = residual + self.drop_path(hidden_states)
  84. hidden_states = self.norm(residual.to(dtype=self.norm.weight.dtype))
  85. if self.residual_in_fp32:
  86. residual = residual.to(torch.float32)
  87. else:
  88. fused_add_norm_fn = rms_norm_fn if isinstance(self.norm, RMSNorm) else layer_norm_fn
  89. if residual is None:
  90. hidden_states, residual = fused_add_norm_fn(
  91. hidden_states,
  92. self.norm.weight,
  93. self.norm.bias,
  94. residual=residual,
  95. prenorm=True,
  96. residual_in_fp32=self.residual_in_fp32,
  97. eps=self.norm.eps,
  98. )
  99. else:
  100. hidden_states, residual = fused_add_norm_fn(
  101. self.drop_path(hidden_states),
  102. self.norm.weight,
  103. self.norm.bias,
  104. residual=residual,
  105. prenorm=True,
  106. residual_in_fp32=self.residual_in_fp32,
  107. eps=self.norm.eps,
  108. )
  109. hidden_states = self.mixer(hidden_states, inference_params=inference_params)
  110. return hidden_states, residual
  111. def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
  112. return self.mixer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
  113. def create_block(
  114. d_model,
  115. ssm_cfg=None,
  116. norm_epsilon=1e-5,
  117. drop_path=0.,
  118. rms_norm=False,
  119. residual_in_fp32=False,
  120. fused_add_norm=False,
  121. layer_idx=None,
  122. device=None,
  123. dtype=None,
  124. if_bimamba=False,
  125. bimamba_type="none",
  126. if_devide_out=False,
  127. init_layer_scale=None,
  128. ):
  129. if if_bimamba:
  130. bimamba_type = "v1"
  131. if ssm_cfg is None:
  132. ssm_cfg = {}
  133. factory_kwargs = {"device": device, "dtype": dtype}
  134. mixer_cls = partial(Mamba, layer_idx=layer_idx, bimamba_type=bimamba_type, if_devide_out=if_devide_out, init_layer_scale=init_layer_scale, **ssm_cfg, **factory_kwargs)
  135. norm_cls = partial(
  136. nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
  137. )
  138. block = Block(
  139. d_model,
  140. mixer_cls,
  141. norm_cls=norm_cls,
  142. drop_path=drop_path,
  143. fused_add_norm=fused_add_norm,
  144. residual_in_fp32=residual_in_fp32,
  145. )
  146. block.layer_idx = layer_idx
  147. return block
  148. # https://github.com/huggingface/transformers/blob/c28d04e9e252a1a099944e325685f14d242ecdcd/src/transformers/models/gpt2/modeling_gpt2.py#L454
  149. def _init_weights(
  150. module,
  151. n_layer,
  152. initializer_range=0.02, # Now only used for embedding layer.
  153. rescale_prenorm_residual=True,
  154. n_residuals_per_layer=1, # Change to 2 if we have MLP
  155. ):
  156. if isinstance(module, nn.Linear):
  157. if module.bias is not None:
  158. if not getattr(module.bias, "_no_reinit", False):
  159. nn.init.zeros_(module.bias)
  160. elif isinstance(module, nn.Embedding):
  161. nn.init.normal_(module.weight, std=initializer_range)
  162. if rescale_prenorm_residual:
  163. # Reinitialize selected weights subject to the OpenAI GPT-2 Paper Scheme:
  164. # > A modified initialization which accounts for the accumulation on the residual path with model depth. Scale
  165. # > the weights of residual layers at initialization by a factor of 1/√N where N is the # of residual layers.
  166. # > -- GPT-2 :: https://openai.com/blog/better-language-models/
  167. #
  168. # Reference (Megatron-LM): https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/model/gpt_model.py
  169. for name, p in module.named_parameters():
  170. if name in ["out_proj.weight", "fc2.weight"]:
  171. # Special Scaled Initialization --> There are 2 Layer Norms per Transformer Block
  172. # Following Pytorch init, except scale by 1/sqrt(2 * n_layer)
  173. # We need to reinit p since this code could be called multiple times
  174. # Having just p *= scale would repeatedly scale it down
  175. nn.init.kaiming_uniform_(p, a=math.sqrt(5))
  176. with torch.no_grad():
  177. p /= math.sqrt(n_residuals_per_layer * n_layer)
  178. def segm_init_weights(m):
  179. if isinstance(m, nn.Linear):
  180. trunc_normal_(m.weight, std=0.02)
  181. if isinstance(m, nn.Linear) and m.bias is not None:
  182. nn.init.constant_(m.bias, 0)
  183. elif isinstance(m, nn.Conv2d):
  184. # NOTE conv was left to pytorch default in my original init
  185. lecun_normal_(m.weight)
  186. if m.bias is not None:
  187. nn.init.zeros_(m.bias)
  188. elif isinstance(m, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
  189. nn.init.zeros_(m.bias)
  190. nn.init.ones_(m.weight)
  191. class VisionMamba(nn.Module):
  192. def __init__(self,
  193. img_size=224,
  194. patch_size=16,
  195. stride=16,
  196. depth=24,
  197. embed_dim=192,
  198. channels=3,
  199. num_classes=1000,
  200. ssm_cfg=None,
  201. drop_rate=0.,
  202. drop_path_rate=0.1,
  203. norm_epsilon: float = 1e-5,
  204. rms_norm: bool = False,
  205. initializer_cfg=None,
  206. fused_add_norm=False,
  207. residual_in_fp32=False,
  208. device=None,
  209. dtype=None,
  210. ft_seq_len=None,
  211. pt_hw_seq_len=14,
  212. if_bidirectional=False,
  213. final_pool_type='none',
  214. if_abs_pos_embed=False,
  215. if_rope=False,
  216. if_rope_residual=False,
  217. flip_img_sequences_ratio=-1.,
  218. if_bimamba=False,
  219. bimamba_type="none",
  220. if_cls_token=False,
  221. if_devide_out=False,
  222. init_layer_scale=None,
  223. use_double_cls_token=False,
  224. use_middle_cls_token=False,
  225. **kwargs):
  226. factory_kwargs = {"device": device, "dtype": dtype}
  227. # add factory_kwargs into kwargs
  228. kwargs.update(factory_kwargs)
  229. super().__init__()
  230. self.residual_in_fp32 = residual_in_fp32
  231. self.fused_add_norm = fused_add_norm
  232. self.if_bidirectional = if_bidirectional
  233. self.final_pool_type = final_pool_type
  234. self.if_abs_pos_embed = if_abs_pos_embed
  235. self.if_rope = if_rope
  236. self.if_rope_residual = if_rope_residual
  237. self.flip_img_sequences_ratio = flip_img_sequences_ratio
  238. self.if_cls_token = if_cls_token
  239. self.use_double_cls_token = use_double_cls_token
  240. self.use_middle_cls_token = use_middle_cls_token
  241. self.num_tokens = 1 if if_cls_token else 0
  242. # pretrain parameters
  243. self.num_classes = num_classes
  244. self.d_model = self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
  245. self.patch_embed = PatchEmbed(
  246. img_size=img_size, patch_size=patch_size, stride=stride, in_chans=channels, embed_dim=embed_dim)
  247. num_patches = self.patch_embed.num_patches
  248. if if_cls_token:
  249. if use_double_cls_token:
  250. self.cls_token_head = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
  251. self.cls_token_tail = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
  252. self.num_tokens = 2
  253. else:
  254. self.cls_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
  255. # self.num_tokens = 1
  256. if if_abs_pos_embed:
  257. self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, self.embed_dim))
  258. self.pos_drop = nn.Dropout(p=drop_rate)
  259. if if_rope:
  260. half_head_dim = embed_dim // 2
  261. hw_seq_len = img_size // patch_size
  262. self.rope = VisionRotaryEmbeddingFast(
  263. dim=half_head_dim,
  264. pt_seq_len=pt_hw_seq_len,
  265. ft_seq_len=hw_seq_len
  266. )
  267. self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
  268. # TODO: release this comment
  269. dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
  270. # import ipdb;ipdb.set_trace()
  271. inter_dpr = [0.0] + dpr
  272. self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0. else nn.Identity()
  273. # transformer blocks
  274. self.layers = nn.ModuleList(
  275. [
  276. create_block(
  277. embed_dim,
  278. ssm_cfg=ssm_cfg,
  279. norm_epsilon=norm_epsilon,
  280. rms_norm=rms_norm,
  281. residual_in_fp32=residual_in_fp32,
  282. fused_add_norm=fused_add_norm,
  283. layer_idx=i,
  284. if_bimamba=if_bimamba,
  285. bimamba_type=bimamba_type,
  286. drop_path=inter_dpr[i],
  287. if_devide_out=if_devide_out,
  288. init_layer_scale=init_layer_scale,
  289. **factory_kwargs,
  290. )
  291. for i in range(depth)
  292. ]
  293. )
  294. # output head
  295. self.norm_f = (nn.LayerNorm if not rms_norm else RMSNorm)(
  296. embed_dim, eps=norm_epsilon, **factory_kwargs
  297. )
  298. # self.pre_logits = nn.Identity()
  299. # original init
  300. self.patch_embed.apply(segm_init_weights)
  301. self.head.apply(segm_init_weights)
  302. if if_abs_pos_embed:
  303. trunc_normal_(self.pos_embed, std=.02)
  304. if if_cls_token:
  305. if use_double_cls_token:
  306. trunc_normal_(self.cls_token_head, std=.02)
  307. trunc_normal_(self.cls_token_tail, std=.02)
  308. else:
  309. trunc_normal_(self.cls_token, std=.02)
  310. # mamba init
  311. self.apply(
  312. partial(
  313. _init_weights,
  314. n_layer=depth,
  315. **(initializer_cfg if initializer_cfg is not None else {}),
  316. )
  317. )
  318. def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
  319. return {
  320. i: layer.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
  321. for i, layer in enumerate(self.layers)
  322. }
  323. @torch.jit.ignore
  324. def no_weight_decay(self):
  325. return {"pos_embed", "cls_token", "dist_token", "cls_token_head", "cls_token_tail"}
  326. @torch.jit.ignore()
  327. def load_pretrained(self, checkpoint_path, prefix=""):
  328. _load_weights(self, checkpoint_path, prefix)
  329. def forward_features(self, x, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
  330. # taken from https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/vision_transformer.py# with slight modifications to add the dist_token
  331. x = self.patch_embed(x)
  332. B, M, _ = x.shape
  333. if self.if_cls_token:
  334. if self.use_double_cls_token:
  335. cls_token_head = self.cls_token_head.expand(B, -1, -1)
  336. cls_token_tail = self.cls_token_tail.expand(B, -1, -1)
  337. token_position =[0, M + 1]
  338. x = torch.cat((cls_token_head, x, cls_token_tail), dim=1)
  339. M = x.shape[1]
  340. else:
  341. if self.use_middle_cls_token:
  342. cls_token = self.cls_token.expand(B,-1,-1)
  343. token_position = M //2
  344. # add cls token in the middle
  345. x = torch.cat((x[:,:token_position,:], cls_token, x[:, token_position:,:]), dim=1)
  346. elif if_random_cls_token_position:
  347. cls_token = self.cls_token.expand(B,-1,-1)
  348. token_position = random.randint(0, M)
  349. x = torch.cat((x[:,:token_position,:], cls_token, x[:, token_position:,:]), dim=1)
  350. print("token_position: ", token_position)
  351. else:
  352. cls_token = self.cls_token.expand(B,-1,-1) # stole cls_tokens impl from Phil Wang, thanks
  353. token_position =0
  354. x = torch.cat((cls_token, x), dim=1)
  355. M = x.shape[1]
  356. if self.if_abs_pos_embed:
  357. # if new_grid_size[0] == self.patch_embed.grid_size[0] and new_grid_size[1] == self.patch_embed.grid_size[1]:
  358. # x = x + self.pos_embed
  359. # else:
  360. # pos_embed = interpolate_pos_embed_online(
  361. # self.pos_embed, self.patch_embed.grid_size, new_grid_size,0
  362. # )
  363. x = x + self.pos_embed
  364. x = self.pos_drop(x)
  365. if if_random_token_rank:
  366. # 生成随机 shuffle 索引
  367. shuffle_indices = torch.randperm(M)
  368. if isinstance(token_position, list):
  369. print("original value: ", x[0, token_position[0],0], x[0, token_position[1],0])
  370. else:
  371. print("original value: ", x[0, token_position,0])
  372. print("original token_position: ", token_position)
  373. # 执行 shuffle
  374. x = x[:, shuffle_indices,:]
  375. if isinstance(token_position, list):
  376. # 找到 cls token 在 shuffle 之后的新位置
  377. new_token_position = [torch.where(shuffle_indices == token_position[i])[0].item() for i in range(len(token_position))]
  378. token_position = new_token_position
  379. else:
  380. # 找到 cls token 在 shuffle 之后的新位置
  381. token_position = torch.where(shuffle_indices == token_position)[0].item()
  if isinstance(token_position, list):
  382. print("new value: ", x[0, token_position[0], 0], x[0, token_position[1], 0])
  383. else:
  384. print("new value: ", x[0, token_position, 0])
  385. print("new token_position: ", token_position)
  386. if_flip_img_sequences = False
  387. if self.flip_img_sequences_ratio >0 and (self.flip_img_sequences_ratio - random.random())> 1e-5:
  388. x = x.flip([1])
  389. if_flip_img_sequences = True
  390. # mamba impl
  391. residual = None
  392. hidden_states = x
  393. if not self.if_bidirectional:
  394. for layer in self.layers:
  395. if if_flip_img_sequences and self.if_rope:
  396. hidden_states = hidden_states.flip([1])
  if residual is not None:
  397. residual = residual.flip([1])
  # rope about
  if self.if_rope:
  398. hidden_states = self.rope(hidden_states)
  if residual is not None and self.if_rope_residual:
  399. residual = self.rope(residual)
  if if_flip_img_sequences and self.if_rope:
  400. hidden_states = hidden_states.flip([1])
  if residual is not None:
  401. residual = residual.flip([1])
  402. hidden_states, residual = layer(
  403. hidden_states, residual, inference_params=inference_params
  404. )
  405. else:
  406. # get two layers in a single for-loop
  for i in range(len(self.layers) // 2):
  407. if self.if_rope:
  408. hidden_states = self.rope(hidden_states)
  if residual is not None and self.if_rope_residual:
  409. residual = self.rope(residual)
  410. hidden_states_f, residual_f = self.layers[i * 2](
  411. hidden_states, residual, inference_params=inference_params
  412. )
  413. hidden_states_b, residual_b = self.layers[i * 2 + 1](
  414. hidden_states.flip([1]), None if residual == None else residual.flip([1]), inference_params=inference_params
  415. )
  416. hidden_states = hidden_states_f + hidden_states_b.flip([1])
  417. residual = residual_f + residual_b.flip([1])
  if not self.fused_add_norm:
  418. if residual is None:
  419. residual = hidden_states
  420. else:
  421. residual = residual + self.drop_path(hidden_states)
  422. hidden_states = self.norm_f(residual.to(dtype=self.norm_f.weight.dtype))
  423. else:
  424. # Set prenorm=False here since we don't need the residual
  425. fused_add_norm_fn = rms_norm_fn if isinstance(self.norm_f, RMSNorm) else layer_norm_fn
  426. hidden_states = fused_add_norm_fn(
  427. self.drop_path(hidden_states),
  428. self.norm_f.weight,
  429. self.norm_f.bias,
  430. eps=self.norm_f.eps,
  431. residual=residual,
  432. prenorm=False,
  433. residual_in_fp32=self.residual_in_fp32,
  434. )
  # return only cls token if it exists
  if self.if_cls_token:
  435. if self.use_double_cls_token:
  436. return (hidden_states[:, token_position[0], :] + hidden_states[:, token_position[1], :]) / 2
  437. else:
  438. if self.use_middle_cls_token:
  439. return hidden_states[:, token_position, :]
  elif if_random_cls_token_position:
  440. return hidden_states[:, token_position, :]
  441. else:
  442. return hidden_states[:, token_position, :]
  if self.final_pool_type == 'none':
  return hidden_states[:, -1, :]
  elif self.final_pool_type == 'mean':
  return hidden_states.mean(dim=1)
  elif self.final_pool_type == 'max':
  return hidden_states
  443. elif self.final_pool_type == 'all':
  return hidden_states
  444. else:
  445. raise NotImplementedError
  446. def forward(self, x, return_features=False, inference_params=None, if_random_cls_token_position=False, if_random_token_rank=False):
  447. x = self.forward_features(x, inference_params, if_random_cls_token_position=if_random_cls_token_position, if_random_token_rank=if_random_token_rank)
  if return_features:
  448. return x
  449. x = self.head(x)
  if self.final_pool_type == 'max':
  450. x = x.max(dim=1)[0]
  return x
  451. @register_model
  452. def vim_tiny_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
  453. model = VisionMamba(patch_size=16, embed_dim=192, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
  454. model.default_cfg = _cfg()
  if pretrained:
  455. checkpoint = torch.hub.load_state_dict_from_url(url="to.do",
  456. map_location="cpu", check_hash=True
  457. )
  458. model.load_state_dict(checkpoint["model"])
  return model
  459. @register_model
  460. def vim_tiny_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
  461. model = VisionMamba(patch_size=16, stride=8, embed_dim=192, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
  462. model.default_cfg = _cfg()
  if pretrained:
  463. checkpoint = torch.hub.load_state_dict_from_url(url="to.do",
  464. map_location="cpu", check_hash=True
  465. )
  466. model.load_state_dict(checkpoint["model"])
  return model
  467. @register_model
  468. def vim_small_patch16_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
  469. model = VisionMamba(patch_size=16, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
  470. model.default_cfg = _cfg()
  if pretrained:
  471. checkpoint = torch.hub.load_state_dict_from_url(url="to.do",
  472. map_location="cpu", check_hash=True
  473. )
  474. model.load_state_dict(checkpoint["model"])
  return model
  475. @register_model
  476. def vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False, **kwargs):
  477. model = VisionMamba(patch_size=16, stride=8, embed_dim=384, depth=24, rms_norm=True, residual_in_fp32=True, fused_add_norm=True, final_pool_type='mean', if_abs_pos_embed=True, if_rope=False, if_rope_residual=False, bimamba_type="v2", if_cls_token=True, if_devide_out=True, use_middle_cls_token=True, **kwargs)
  478. model.default_cfg = _cfg()
  if pretrained:
  479. checkpoint = torch.hub.load_state_dict_from_url(url="to.do",
  480. map_location="cpu", check_hash=True
  481. )
  482. model.load_state_dict(checkpoint["model"])
  return model
  483. if __name__ == '__main__':
  # cuda or cpu
  484. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  485. print(device)
  # instantiate the full model and get a classification result
  486. inputs = torch.randn(1, 3, 224, 224).to(device)
  487. model = vim_small_patch16_stride8_224_bimambav2_final_pool_mean_abs_pos_embed_with_midclstok_div2(pretrained=False).to(device)
  488. print(model)
  489. outputs = model(inputs)
  490. print(outputs.shape)
  # instantiate a single Mamba block: input and output feature shapes stay the same (B C H W)
  491. x = torch.rand(10, 16, 64, 128).to(device)
  492. B, C, H, W = x.shape
  493. print("input feature shape:", x.shape)
  494. x = x.view(B, C, H * W).permute(0, 2, 1)
  495. print("reshaped:", x.shape)
  496. mamba = create_block(d_model=C).to(device)
  # the Mamba block returns a tuple: hidden_states, residual
  497. hidden_states, residual = mamba(x)
  498. x = hidden_states.permute(0, 2, 1).view(B, C, H, W)
  499. print("output feature shape:", x.shape)
Run results

[Figure: console output of the run]


2.2 Module 2: MambaIR

Bilibili uploader: @箫张跋扈

Video: "Mamba Back! A plug-and-play module from the Mamba field (TimeMachine) for time-series tasks!"

After downloading the code, put the code below into MambaIR.py and run it to get the result. (A small wrapper sketch for plugging VSSBlock into a network that works on B, C, H, W feature maps is given after the run results.)

Code: MambaIR
  1. # Code Implementation of the MambaIR Model
  import warnings
  2. warnings.filterwarnings("ignore")
  import math
  3. import torch
  4. import torch.nn as nn
  5. import torch.nn.functional as F
  6. from functools import partial
  7. from typing import Optional, Callable
  8. from timm.models.layers import DropPath, to_2tuple, trunc_normal_
  9. from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, selective_scan_ref
  10. from einops import rearrange, repeat
  11. """
  12. 最近,选择性结构化状态空间模型,特别是改进版本的Mamba,在线性复杂度的远程依赖建模方面表现出了巨大的潜力。
  13. 然而,标准Mamba在低级视觉方面仍然面临一定的挑战,例如局部像素遗忘和通道冗余。在这项工作中,我们引入了局部增强和通道注意力来改进普通 Mamba。
  14. 通过这种方式,我们利用了局部像素相似性并减少了通道冗余。大量的实验证明了我们方法的优越性。
  15. """
  16. NEG_INF =-1000000
  17. class ChannelAttention(nn.Module):
  18. """Channel attention used in RCAN.
  19. Args:
  20. num_feat (int): Channel number of intermediate features.
  21. squeeze_factor (int): Channel squeeze factor. Default: 16.
  22. """
  23. def __init__(self, num_feat, squeeze_factor=16):
  24. super(ChannelAttention, self).__init__()
  25. self.attention = nn.Sequential(
  26. nn.AdaptiveAvgPool2d(1),
  27. nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
  28. nn.ReLU(inplace=True),
  29. nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
  30. nn.Sigmoid())
  31. def forward(self, x):
  32. y = self.attention(x)
  return x * y
  33. class CAB(nn.Module):
  34. def __init__(self, num_feat, is_light_sr= False, compress_ratio=3,squeeze_factor=30):
  35. super(CAB, self).__init__()
  if is_light_sr:  # we use depth-wise conv for light-SR to achieve more efficiency
  36. self.cab = nn.Sequential(
  37. nn.Conv2d(num_feat, num_feat, 3, 1, 1, groups=num_feat),
  38. ChannelAttention(num_feat, squeeze_factor))
  39. else: # for classic SR
  40. self.cab = nn.Sequential(
  41. nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
  42. nn.GELU(),
  43. nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
  44. ChannelAttention(num_feat, squeeze_factor))
  45. def forward(self, x):
  46. return self.cab(x)
  47. class Mlp(nn.Module):
  48. def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
  49. super().__init__()
  50. out_features = out_features or in_features
  51. hidden_features = hidden_features or in_features
  52. self.fc1 = nn.Linear(in_features, hidden_features)
  53. self.act = act_layer()
  54. self.fc2 = nn.Linear(hidden_features, out_features)
  55. self.drop = nn.Dropout(drop)
  56. def forward(self, x):
  57. x = self.fc1(x)
  58. x = self.act(x)
  59. x = self.drop(x)
  60. x = self.fc2(x)
  61. x = self.drop(x)
  return x
  62. class DynamicPosBias(nn.Module):
  63. def __init__(self, dim, num_heads):
  64. super().__init__()
  65. self.num_heads = num_heads
  66. self.pos_dim = dim // 4
  67. self.pos_proj = nn.Linear(2, self.pos_dim)
  68. self.pos1 = nn.Sequential(
  69. nn.LayerNorm(self.pos_dim),
  70. nn.ReLU(inplace=True),
  71. nn.Linear(self.pos_dim, self.pos_dim),
  72. )
  73. self.pos2 = nn.Sequential(
  74. nn.LayerNorm(self.pos_dim),
  75. nn.ReLU(inplace=True),
  76. nn.Linear(self.pos_dim, self.pos_dim))
  77. self.pos3 = nn.Sequential(
  78. nn.LayerNorm(self.pos_dim),
  79. nn.ReLU(inplace=True),
  80. nn.Linear(self.pos_dim, self.num_heads))
  81. def forward(self, biases):
  82. pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
  return pos
  83. def flops(self, N):
  84. flops = N * 2 * self.pos_dim
  85. flops += N * self.pos_dim * self.pos_dim
  86. flops += N * self.pos_dim * self.pos_dim
  87. flops += N * self.pos_dim * self.num_heads
  88. return flops
  89. class Attention(nn.Module):
  90. r""" Multi-head self attention module with dynamic position bias.
  91. Args:
  92. dim (int): Number of input channels.
  93. num_heads (int): Number of attention heads.
  94. qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
  95. qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5ifset
  96. attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
  97. proj_drop (float, optional): Dropout ratio of output. Default: 0.0"""
  98. def __init__(self, dim, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.,
  99. position_bias=True):
  100. super().__init__()
  101. self.dim = dim
  102. self.num_heads = num_heads
  103. head_dim = dim // num_heads
  104. self.scale = qk_scale or head_dim ** -0.5
  105. self.position_bias = position_bias
  106. if self.position_bias:
  107. self.pos = DynamicPosBias(self.dim // 4, self.num_heads)
  108. self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
  109. self.attn_drop = nn.Dropout(attn_drop)
  110. self.proj = nn.Linear(dim, dim)
  111. self.proj_drop = nn.Dropout(proj_drop)
  112. self.softmax = nn.Softmax(dim=-1)
  113. def forward(self, x, H, W, mask=None):
  114. """
  115. Args:
  116. x: input features with shape of (num_groups*B, N, C)
  117. mask: (0/-inf) mask with shape of (num_groups, Gh*Gw, Gh*Gw) or None
  118. H: height of each group
  119. W: width of each group
  120. """
  121. group_size =(H, W)
  122. B_, N, C = x.shape
  123. assert H * W == N
  124. qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4).contiguous()
  125. q, k, v= qkv[0], qkv[1], qkv[2]
  126. q = q * self.scale
  127. attn = (q @ k.transpose(-2, -1))  # (B_, self.num_heads, N, N), N = H*W
  if self.position_bias:
  128. # generate mother-set
  129. position_bias_h = torch.arange(1 - group_size[0], group_size[0], device=attn.device)
  130. position_bias_w = torch.arange(1 - group_size[1], group_size[1], device=attn.device)
  131. biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))# 2, 2Gh-1, 2W2-1
  132. biases = biases.flatten(1).transpose(0, 1).contiguous().float()# (2h-1)*(2w-1) 2# get pair-wise relative position index for each token inside the window
  133. coords_h = torch.arange(group_size[0], device=attn.device)
  134. coords_w = torch.arange(group_size[1], device=attn.device)
  135. coords = torch.stack(torch.meshgrid([coords_h, coords_w]))# 2, Gh, Gw
  136. coords_flatten = torch.flatten(coords, 1)# 2, Gh*Gw
  137. relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]# 2, Gh*Gw, Gh*Gw
  138. relative_coords = relative_coords.permute(1, 2, 0).contiguous()# Gh*Gw, Gh*Gw, 2
  139. relative_coords[:, :, 0]+= group_size[0] - 1# shift to start from 0
  140. relative_coords[:, :, 1]+= group_size[1] - 1
  141. relative_coords[:, :, 0] *=2 * group_size[1] - 1
  142. relative_position_index = relative_coords.sum(-1)# Gh*Gw, Gh*Gw
  143. pos = self.pos(biases)# 2Gh-1 * 2Gw-1, heads# select position bias
  144. relative_position_bias = pos[relative_position_index.view(-1)].view(
  145. group_size[0] * group_size[1], group_size[0] * group_size[1], -1)# Gh*Gw,Gh*Gw,nH
  146. relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()# nH, Gh*Gw, Gh*Gw
  147. attn = attn + relative_position_bias.unsqueeze(0)
  if mask is not None:
  148. nP = mask.shape[0]
  149. attn = attn.view(B_ // nP, nP, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)# (B, nP, nHead, N, N)
  150. attn = attn.view(-1, self.num_heads, N, N)
  151. attn = self.softmax(attn)
  152. else:
  153. attn = self.softmax(attn)
  154. attn = self.attn_drop(attn)
  155. x =(attn @ v).transpose(1, 2).reshape(B_, N, C)
  156. x = self.proj(x)
  157. x = self.proj_drop(x)
  return x
  158. class SS2D(nn.Module):
  159. def __init__(
  160. self,
  161. d_model,
  162. d_state=16,
  163. d_conv=3,
  164. expand=2.,
  165. dt_rank="auto",
  166. dt_min=0.001,
  167. dt_max=0.1,
  168. dt_init="random",
  169. dt_scale=1.0,
  170. dt_init_floor=1e-4,
  171. dropout=0.,
  172. conv_bias=True,
  173. bias=False,
  174. device=None,
  175. dtype=None,
  176. **kwargs,
  177. ):
  178. factory_kwargs ={"device": device, "dtype": dtype}
  179. super().__init__()
  180. self.d_model = d_model
  181. self.d_state = d_state
  182. self.d_conv = d_conv
  183. self.expand =expand
  184. self.d_inner = int(self.expand * self.d_model)
  185. self.dt_rank = math.ceil(self.d_model / 16)if dt_rank =="auto"else dt_rank
  186. self.in_proj = nn.Linear(self.d_model, self.d_inner * 2, bias=bias, **factory_kwargs)
  187. self.conv2d = nn.Conv2d(in_channels=self.d_inner,
  188. out_channels=self.d_inner,
  189. groups=self.d_inner,
  190. bias=conv_bias,
  191. kernel_size=d_conv,
  192. padding=(d_conv - 1) // 2,
  193. **factory_kwargs,
  194. )
  195. self.act = nn.SiLU()
  196. self.x_proj =(
  197. nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
  198. nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
  199. nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
  200. nn.Linear(self.d_inner, (self.dt_rank + self.d_state * 2), bias=False, **factory_kwargs),
  201. )
  202. self.x_proj_weight = nn.Parameter(torch.stack([t.weight for t in self.x_proj], dim=0))  # (K=4, N, inner)
  203. del self.x_proj
  204. self.dt_projs =(
  205. self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
  206. **factory_kwargs),
  207. self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
  208. **factory_kwargs),
  209. self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
  210. **factory_kwargs),
  211. self.dt_init(self.dt_rank, self.d_inner, dt_scale, dt_init, dt_min, dt_max, dt_init_floor,
  212. **factory_kwargs),
  213. )
  214. self.dt_projs_weight = nn.Parameter(torch.stack([t.weight for t in self.dt_projs], dim=0))  # (K=4, inner, rank)
  215. self.dt_projs_bias = nn.Parameter(torch.stack([t.bias for t in self.dt_projs], dim=0))  # (K=4, inner)
  216. del self.dt_projs
  217. self.A_logs = self.A_log_init(self.d_state, self.d_inner, copies=4, merge=True)# (K=4, D, N)
  218. self.Ds = self.D_init(self.d_inner, copies=4, merge=True)# (K=4, D, N)
  219. self.selective_scan = selective_scan_fn
  220. self.out_norm = nn.LayerNorm(self.d_inner)
  221. self.out_proj = nn.Linear(self.d_inner, self.d_model, bias=bias, **factory_kwargs)
  222. self.dropout = nn.Dropout(dropout)if dropout >0. else None
  223. @staticmethod
  224. def dt_init(dt_rank, d_inner, dt_scale=1.0, dt_init="random", dt_min=0.001, dt_max=0.1, dt_init_floor=1e-4,
  225. **factory_kwargs):
  226. dt_proj = nn.Linear(dt_rank, d_inner, bias=True, **factory_kwargs)# Initialize special dt projection to preserve variance at initialization
  227. dt_init_std = dt_rank ** -0.5 * dt_scale
  228. if dt_init =="constant":
  229. nn.init.constant_(dt_proj.weight, dt_init_std)
  elif dt_init == "random":
  230. nn.init.uniform_(dt_proj.weight, -dt_init_std, dt_init_std)
  231. else:
  232. raise NotImplementedError
  233. # Initialize dt bias so that F.softplus(dt_bias) is between dt_min and dt_max
  234. dt = torch.exp(
  235. torch.rand(d_inner, **factory_kwargs) * (math.log(dt_max) - math.log(dt_min))
  236. + math.log(dt_min)).clamp(min=dt_init_floor)# Inverse of softplus: https://github.com/pytorch/pytorch/issues/72759
  237. inv_dt = dt + torch.log(-torch.expm1(-dt))
  238. with torch.no_grad():
  239. dt_proj.bias.copy_(inv_dt)# Our initialization would set all Linear.bias to zero, need to mark this one as _no_reinit
  240. dt_proj.bias._no_reinit = True
  241. return dt_proj
  242. @staticmethod
  243. def A_log_init(d_state, d_inner, copies=1, device=None, merge=True):
  244. # S4D real initialization
  245. A = repeat(
  246. torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
  247. "n -> d n",
  248. d=d_inner,
  249. ).contiguous()
  250. A_log = torch.log(A)  # Keep A_log in fp32
  if copies > 1:
  251. A_log = repeat(A_log, "d n -> r d n", r=copies)
  if merge:
  252. A_log = A_log.flatten(0, 1)
  253. A_log = nn.Parameter(A_log)
  254. A_log._no_weight_decay = True
  255. return A_log
  256. @staticmethod
  257. def D_init(d_inner, copies=1, device=None, merge=True):
  258. # D "skip" parameter
  259. D = torch.ones(d_inner, device=device)
  if copies > 1:
  260. D = repeat(D, "n1 -> r n1", r=copies)
  if merge:
  261. D = D.flatten(0, 1)
  262. D = nn.Parameter(D)# Keep in fp32
  263. D._no_weight_decay = True
  264. return D
  265. def forward_core(self, x: torch.Tensor):
  266. B, C, H, W = x.shape
  267. L = H * W
  268. K =4
  269. x_hwwh = torch.stack([x.view(B, -1, L), torch.transpose(x, dim0=2, dim1=3).contiguous().view(B, -1, L)], dim=1).view(B, 2, -1, L)
  270. xs = torch.cat([x_hwwh, torch.flip(x_hwwh, dims=[-1])], dim=1)# (1, 4, 192, 3136)
  271. x_dbl = torch.einsum("b k d l, k c d -> b k c l", xs.view(B, K, -1, L), self.x_proj_weight)
  272. dts, Bs, Cs = torch.split(x_dbl, [self.dt_rank, self.d_state, self.d_state], dim=2)
  273. dts = torch.einsum("b k r l, k d r -> b k d l", dts.view(B, K, -1, L), self.dt_projs_weight)
  274. xs = xs.float().view(B, -1, L)
  275. dts = dts.contiguous().float().view(B, -1, L)# (b, k * d, l)
  276. Bs = Bs.float().view(B, K, -1, L)
  277. Cs = Cs.float().view(B, K, -1, L)# (b, k, d_state, l)
  278. Ds = self.Ds.float().view(-1)
  279. As = -torch.exp(self.A_logs.float()).view(-1, self.d_state)
  280. dt_projs_bias = self.dt_projs_bias.float().view(-1)# (k * d)
  281. out_y = self.selective_scan(
  282. xs, dts,
  283. As, Bs, Cs, Ds, z=None,
  284. delta_bias=dt_projs_bias,
  285. delta_softplus=True,
  286. return_last_state=False,
  287. ).view(B, K, -1, L)
  288. assert out_y.dtype == torch.float
  289. inv_y = torch.flip(out_y[:, 2:4], dims=[-1]).view(B, 2, -1, L)
  290. wh_y = torch.transpose(out_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
  291. invwh_y = torch.transpose(inv_y[:, 1].view(B, -1, W, H), dim0=2, dim1=3).contiguous().view(B, -1, L)
  return out_y[:, 0], inv_y[:, 0], wh_y, invwh_y
  292. def forward(self, x: torch.Tensor, **kwargs):
  293. B, H, W, C = x.shape
  294. xz = self.in_proj(x)
  295. x, z = xz.chunk(2, dim=-1)
  296. x = x.permute(0, 3, 1, 2).contiguous()
  297. x = self.act(self.conv2d(x))
  298. y1, y2, y3, y4 = self.forward_core(x)
  299. assert y1.dtype == torch.float32
  300. y = y1 + y2 + y3 + y4
  301. y = torch.transpose(y, dim0=1, dim1=2).contiguous().view(B, H, W, -1)
  302. y = self.out_norm(y)
  303. y = y * F.silu(z)
  304. out = self.out_proj(y)
  if self.dropout is not None:
  305. out = self.dropout(out)
  return out
  306. class VSSBlock(nn.Module):
  307. def __init__(
  308. self,
  309. hidden_dim: int =0,
  310. drop_path: float =0,
  311. norm_layer: Callable[..., torch.nn.Module]= partial(nn.LayerNorm, eps=1e-6),
  312. attn_drop_rate: float =0,
  313. d_state: int =16,
  314. expand: float =2.,
  315. is_light_sr: bool = False,
  316. **kwargs,
  317. ):
  318. super().__init__()
  319. self.ln_1 = norm_layer(hidden_dim)
  320. self.self_attention = SS2D(d_model=hidden_dim, d_state=d_state,expand=expand,dropout=attn_drop_rate, **kwargs)
  321. self.drop_path = DropPath(drop_path)
  self.skip_scale = nn.Parameter(torch.ones(hidden_dim))
  322. self.conv_blk = CAB(hidden_dim,is_light_sr)
  323. self.ln_2 = nn.LayerNorm(hidden_dim)
  324. self.skip_scale2 = nn.Parameter(torch.ones(hidden_dim))
  325. def forward(self, input, x_size):
  326. # x [B,HW,C]
  327. B, L, C = input.shape
  328. input = input.view(B, *x_size, C).contiguous()# [B,H,W,C]
  329. x = self.ln_1(input)
  330. x = input*self.skip_scale + self.drop_path(self.self_attention(x))
  331. x = x*self.skip_scale2 + self.conv_blk(self.ln_2(x).permute(0, 3, 1, 2).contiguous()).permute(0, 2, 3, 1).contiguous()
  332. x = x.view(B, -1, C).contiguous()
  return x
  333. if __name__ == '__main__':
  # initialize the VSSBlock module with hidden_dim=128
  334. block = VSSBlock(hidden_dim=128, drop_path=0.1, attn_drop_rate=0.1, d_state=16, expand=2.0, is_light_sr=False)
  # move the module to the appropriate device
  335. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
  336. block = block.to(device)
  # create a random input tensor of shape [B, H*W, C]: batch size 4, 32x32 spatial size, 128 channels
  337. B, H, W, C = 4, 32, 32, 128
  338. input_tensor = torch.rand(B, H * W, C).to(device)
  # compute the output
  339. output_tensor = block(input_tensor, (H, W))
  # print the input and output tensor sizes
  340. print("Input tensor size:", input_tensor.size())
  341. print("Output tensor size:", output_tensor.size())
Run results

[Figure: console output of the run]
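As mentioned above, if you want to plug VSSBlock into your own network on a standard (B, C, H, W) feature map, the only extra work is reshaping to and from the (B, H*W, C) token layout that VSSBlock expects. Here is a small sketch; the wrapper class is hypothetical (not part of MambaIR) and assumes the file above is saved as MambaIR.py so that VSSBlock can be imported.

```python
import torch
import torch.nn as nn
# from MambaIR import VSSBlock   # assuming the listing above was saved as MambaIR.py

class VSSBlockWrapper(nn.Module):
    """Adapts (B, C, H, W) features to VSSBlock's (B, H*W, C) interface and back."""
    def __init__(self, channels):
        super().__init__()
        self.block = VSSBlock(hidden_dim=channels, drop_path=0.1, attn_drop_rate=0.1,
                              d_state=16, expand=2.0, is_light_sr=False)

    def forward(self, x):                         # x: (B, C, H, W)
        B, C, H, W = x.shape
        tokens = x.flatten(2).transpose(1, 2)     # (B, H*W, C)
        tokens = self.block(tokens, (H, W))       # spatial size is passed separately
        return tokens.transpose(1, 2).reshape(B, C, H, W)
```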


Chapter 3: Key Papers and Paper Tracking

Original Mamba paper: Mamba: Linear-Time Sequence Modeling with Selective State Spaces

Key papers

  1. Vision Mamba@Vision Mamba: Efficient Visual Representation Learning with Bidirectional State Space Model
  2. MambaIR@MambaIR: A Simple Baseline for Image Restoration with State-Space Model
  3. U-Mamba@U-Mamba: Enhancing Long-range Dependency for Biomedical Image Segmentation

Tracking the Mamba paper series

  1. The GitHub repository below collects Mamba-based papers from different fields.

Mamba_State_Space_Model_Paper_List: https://github.com/Event-AHU/Mamba_State_Space_Model_Paper_List


Chapter 4: Mamba Theory and Analysis

  1. We will use the FusionMamba paper to understand Mamba.

FusionMamba: Efficient Image Fusion with State Space Model [paper reading]

The Mamba module

Let us borrow Figure 3 from that paper to study the structure of the Mamba module:

[Figure 3 of FusionMamba: structure of the Mamba module]
The leftmost part of the figure is the Mamba module. The Vision Mamba module performs feature extraction on a feature map, so we expect the feature map to keep the same size after passing through the Mamba module.

Part 1: The input feature map F_in, of size H, W, C, is fed into a LayerNorm layer and projected into two different features X and Z, both keeping the size H, W, C.
Part 2: X is flattened along 4 different directions into 1-D feature sequences, each of size HW, C (this is similar to the tokenization step in a Transformer: convert to tokens, then do the subsequent computation on them). The 4 flattening directions, shown at the far right of the figure, are the row-wise and column-wise scans and their reverses.
Part 3: The 4 directional 1-D feature sequences are fed into the SSM module for feature extraction. The SSM module is clearly the core of the Mamba module; we will analyze it in detail later in this post.
Part 4: The output feature sequences, of size HW, C, are unflattened back into feature maps of size H, W, C, and the 4 directional feature maps are summed so that they are fully fused into the feature Y.
Part 5: The initial feature Z is passed through SiLU as a non-linear mapping and used as a weight (a kind of attention) to gate the fused feature map Y, highlighting the salient features. Finally, the result is projected by a 1x1 convolution and added to the input feature through a residual connection to obtain the final output feature F_out. A minimal code sketch of this data flow is given right after this list.
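To make the data flow concrete, here is a minimal PyTorch sketch of the block described above. It is only an illustration of the shapes and of the scan / gate / residual structure: the projections are plain Linear layers, the SSM is replaced by an identity placeholder, and the class name MambaBlockSketch is mine, not from the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class MambaBlockSketch(nn.Module):
    """Shape-level sketch of the ViM/FusionMamba block; the real selective SSM is omitted."""
    def __init__(self, dim, ssm=None):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.proj_x = nn.Linear(dim, dim)                      # produces X
        self.proj_z = nn.Linear(dim, dim)                      # produces Z (gating branch)
        self.ssm = ssm if ssm is not None else nn.Identity()   # placeholder for the SSM
        self.out_proj = nn.Linear(dim, dim)                    # plays the role of the 1x1 conv

    def forward(self, f_in):                                   # f_in: (B, H, W, C)
        B, H, W, C = f_in.shape
        h = self.norm(f_in)
        x, z = self.proj_x(h), self.proj_z(h)

        # Part 2: flatten X along four scan directions (row-wise, column-wise, and their reverses)
        seq_h = x.reshape(B, H * W, C)
        seq_v = x.permute(0, 2, 1, 3).reshape(B, H * W, C)
        dirs = [seq_h, seq_h.flip(1), seq_v, seq_v.flip(1)]

        # Parts 3 and 4: run the (placeholder) SSM on each direction, unflatten, and sum into Y
        outs = [self.ssm(d) for d in dirs]
        y = outs[0].reshape(B, H, W, C)
        y = y + outs[1].flip(1).reshape(B, H, W, C)
        y = y + outs[2].reshape(B, W, H, C).permute(0, 2, 1, 3)
        y = y + outs[3].flip(1).reshape(B, W, H, C).permute(0, 2, 1, 3)

        # Part 5: gate with SiLU(Z), project, and add the residual
        y = y * F.silu(z)
        return f_in + self.out_proj(y)

if __name__ == "__main__":
    f = torch.randn(2, 14, 14, 96)
    print(MambaBlockSketch(96)(f).shape)   # torch.Size([2, 14, 14, 96])
```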

The key SSM algorithm

Following the flow chart given in that paper, let us build a thorough understanding of the SSM algorithm. See the leftmost part of the figure below; the right side can be ignored, as it is the authors' improvement.

[Figure: SSM flow chart from FusionMamba]

  1. SSM Block: to be continued... (the standard SSM equations are recalled below as a placeholder)
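Until that part is written, it may help to recall the standard state-space formulation that the SSM in Mamba builds on (this is the textbook form from the S4/Mamba papers, not something specific to FusionMamba). The continuous system and its zero-order-hold discretization are

$$h'(t) = A\,h(t) + B\,x(t), \qquad y(t) = C\,h(t)$$

$$\bar{A} = \exp(\Delta A), \qquad \bar{B} = (\Delta A)^{-1}\left(\exp(\Delta A) - I\right)\Delta B$$

$$h_t = \bar{A}\,h_{t-1} + \bar{B}\,x_t, \qquad y_t = C\,h_t$$

The "selective" part of Mamba is that B, C, and the step size Δ are computed from the input x rather than being fixed parameters.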

Chapter 5: Summary and Outlook

  1. 2024-04-29 16:57:45: Today I finished installing the environment, instantiating the plug-and-play modules, and sharing the related papers. In the near future I will study Mamba thoroughly and then share its theory, to help readers quickly grasp the theory in the original Mamba paper.
  2. 2024-05-02 15:56:32: Today I added the basics of the Mamba module based on the FusionMamba paper. Next I will focus on the SSM module inside it, which will complete this post.


Reposted from: https://blog.csdn.net/weixin_43312117/article/details/138316513
Copyright belongs to the original author 陈嘿萌. If there is any infringement, please contact us and we will remove it.
