【语义分割】ST_Unet论文 逐步代码解读
文章目录
一、代码整体解读
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-UAEkMEUl-1678889964762)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310143528765.png)]](https://img-blog.csdnimg.cn/adc0a302f2694975a92a1a5c9392729b.png#pic_center)
主要工程文件为这5个
分别作用为:
- 构造相应的deform 卷积
- DCNN的残差网络
- 编写相应的配置文件,可以改变相应参数
- 模型的主函数和主框架
- 模型的连接部分
二、辅助Decode代码框架
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-xB9iqcKa-1678889964763)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310145258457.png)]](https://img-blog.csdnimg.cn/75ba89a9122d4eb2b87bf76b25ddfb37.png#pic_center)
代码框架由3部分组成,encode,decode和decode中将图像还原成语义分割预测图
- Transformer(config, img_size) 组成编码部分,包含主编码的DCNN和辅助的transformer
- DecoderCup(config)组成解码部分,图像还原为[N,64,128,128]
- SegmentationHead将图像变成6分类的[N,6,256,256]的图像
2.1 混合transformer和cnn的模型
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-C1JDIQD6-1678889964763)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310150618037.png)]](https://img-blog.csdnimg.cn/e162503d3a604de1a7682bbb57332881.png#pic_center)
整体思路是这样的,decode一共分为4个阶段
主要用空的数组来保存每一个阶段的输出值,与DCNN在每一个阶段通过RAM进行连接
在class TransResNetV2(nn.Module)函数中进行相应的具体编写
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-fFrLblTw-1678889964763)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310152428267.png)]](https://img-blog.csdnimg.cn/7801673efed34f63ad442f4619ddb98e.png#pic_center)
相应的RAM操作示意图
An和Sn分别表示第n阶段主编码器和辅助编码器的输出
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Ua68cdLE-1678889964764)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310152957811.png)]](https://img-blog.csdnimg.cn/5ac72b023b6d402d8cc25e3cdfa2a62a.png#pic_center)
一共组成分为4部分,将每一层都进行相应的整合,最后放在数组里面
2.2 Swin transformer 部分
将读入的数据进行打平操作
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OYQ7SY3i-1678889964764)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310154540561.png)]](https://img-blog.csdnimg.cn/32c07ee8f766467bb942c7d6203e9b14.png#pic_center)
embeddings(trans_x)
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Y41JX67e-1678889964764)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310161441181.png)]](https://img-blog.csdnimg.cn/58aec6c1f1454b138978796fe9bf68d8.png#pic_center)
这部分操作的一般情况下的Swin transformer一样,同样满足(2,2,6,2)的层数结构,只不过是,加入了相应的残差结构,经过了扁平化操作后的数据类型为[12, 4096, 96]
具体的transformer操作在这部分进行
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-3uO7glYk-1678889964764)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310161944441.png)]](https://img-blog.csdnimg.cn/1074e4542f994652b6c12bca8ccaf01f.png#pic_center)
在这个函数中主要是transformer块和SIM的残差组合
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-VWJFuFdd-1678889964765)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310163215108.png)]](https://img-blog.csdnimg.cn/1ab25816b47d403c95bb173df255d8c5.png#pic_center)
在这步之后,x就可以就是组成的tranform块输出的格式,其中是由shift_size来进行判断是W-MSA,还是SW-MSA,来进行的窗口移动,还是就单纯的结构的划分
shift_size=0if(i %2==0)else window_size //2,# 判断是不是SW_MSA
起了决定性的判断作用
这段代码进行执行4次,将每次执行的结果进行保存
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-X8GyKSQg-1678889964765)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310170340544.png)]](https://img-blog.csdnimg.cn/1d239bcce806496da1ea9295eaf69621.png#pic_center)
2.3 FCM 部分
if(i_layer < self.num_layers -1):
trans_x = self.Down_features[i_layer](trans_x)
在每一个Swin transformer阶段都进行了下采样,除了最后一个阶段
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4UiGLHwU-1678889964765)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310172543159.png)]](https://img-blog.csdnimg.cn/0eb9f041055a478e9198931fcbf2118e.png#pic_center)
结构示意图
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-hp318ekr-1678889964766)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310172916320.png)]](https://img-blog.csdnimg.cn/e0a3775449e541fbbd3b5daa3fbc0fc2.png#pic_center)
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-5qY2UKUs-1678889964766)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310173458842.png)]](https://img-blog.csdnimg.cn/15c82bffb1e1498385874512f62159a9.png#pic_center)
整体的代码逻辑都是按照这个思路来的,来进行的整合和结合
三、主Decode代码框架
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-p157EjKs-1678889964766)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310174011667.png)]](https://img-blog.csdnimg.cn/c18b94dd08204074862aaf06a12073b0.png#pic_center)
首先进行的是root函数,主打的是一个对图片进行预先处理
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-JhTJnNgQ-1678889964766)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310190150355.png)]](https://img-blog.csdnimg.cn/369948868ffe4f108ced1e9a15a78248.png#pic_center)
对图片进行相应的变形,主要还是三步走的对策,卷积,归一,relu。进行DCNN卷积网络时基本都是这样进行的
3.1 基本卷积模块
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-bNtANlad-1678889964766)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310190743797.png)]](https://img-blog.csdnimg.cn/d1e257cb2ddd4563beb95f42f59a4b89.png#pic_center)
将这两个归为一个操作,body里面是几个卷积的模块config.resnet.num_layers = (3, 4, 6, 3)组成的,重复的次数由设定好的值来进行重复
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-DgRpAOxk-1678889964767)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310191345796.png)]](https://img-blog.csdnimg.cn/1c09fcf4dc7b4a4b9d8f35eca56efb64.png#pic_center)
PreActBottleneck(nn.Module) 里面的值就是很单纯的DCNN的卷积网络的堆积
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-Zc9qXqRd-1678889964767)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310191640918.png)]](https://img-blog.csdnimg.cn/a3b9618fbc434d5d8d81288f3a60fe2a.png#pic_center)
在这里DCN加了一个, DeformConv2d,这个函数是自己编写的,一个可变形的卷积操作,其实他本质上和普通的卷积操作一样
后面也是相同的操作,通过RAM模块将相应的结果组合在一起
3.2 RAM
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-pvpyR3hG-1678889964767)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310211839414.png)]](https://img-blog.csdnimg.cn/07c3c76d86ff4a48afac5e21d2503530.png#pic_center)
输入分为了主编码器和辅助编码器,总共的结合组成为3种,将不同变化的进行拼接
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-OPLAKkjh-1678889964767)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310212655800.png)]](https://img-blog.csdnimg.cn/38066d8a579f4ea5a85d15eb936f5981.png#pic_center)
相应参数:
- x 原始参数
- short 经过了注意力通道
- s1 tranformer辅助通道过来的数据
3.3 输出参数
输出参数值主要分为两类:
- 结合所有参数的X [N, 32, 256, 256]
- 每个阶段提取出来的特征数据features
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-9rD0FkJk-1678889964768)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310213311694.png)]](https://img-blog.csdnimg.cn/9c496d28ea154a278771f16a254f4af8.png#pic_center)
将这两个数据进行带入Encode,进行解码,可以逐步还原成原始图像
四、Encode代码
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-nVjVVQLc-1678889964768)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310213517652.png)]](https://img-blog.csdnimg.cn/7713d41b6c9a48aeb24c357cb4bec047.png#pic_center)
代码主要分为两步来实现:
- x 的卷积上采样
- x与skip的融合后,进行相应的卷积操作
skip是每个特征层的进过RAM后的保存数据,所有的融合卷积操作在block中完成
4.1 block函数解析
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-ksA8GBvs-1678889964768)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310233943010.png)]](https://img-blog.csdnimg.cn/f8517f22b4b349b1b5464a3b32a7b94b.png#pic_center)
在连接阶段主要是conv1和conv2,这两个函数, 进行上采样来保存维度一致,使他可以cat在一起 conv3和conv4在连接完成后,进行相应的上采样环节来使图像还原成原来的[n, 6, 256, 256]
4.2 上采样还原
![[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-vdcEgMW9-1678889964768)(C:\Users\isipa\AppData\Roaming\Typora\typora-user-images\image-20230310235005427.png)]](https://img-blog.csdnimg.cn/f15ef4a91398467793581144f7fa7477.png#pic_center)
这部分代码将这里独立出去了
在这里x的输入参数应该是(N,16,256,256)
在进行了一次卷积和上采用后,就恢复成了原始图像
版权归原作者 川川子溢 所有, 如有侵权,请联系我们删除。