

Source Code for an AI Image Dehazing System Based on PSD and Deep Neural Networks

Step 1: Introduction to PSD

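  1. Previous work has mostly trained models on synthetic hazy images, and these models lose performance when applied to real-world hazy images.
  2. To address this and improve the generalization of dehazing, the authors propose the Principled Synthetic-to-real Dehazing (PSD) framework.
  3. PSD is designed to carry existing dehazing models over to the real domain and consists of two stages: **supervised pre-training** and **unsupervised fine-tuning**.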

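**Pre-training stage**: the authors modify the backbone of a chosen dehazing model into a physics-based network and train it on synthetic data. This well-designed backbone yields a pre-trained model with good dehazing performance on the synthetic domain.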

**Fine-tuning stage**: the authors train the model on real hazy images in an **unsupervised** manner.

Contributions of the paper:

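  1. The authors reformulate real-world dehazing as a synthetic-to-real generalization framework: a dehazing backbone is first pre-trained on paired synthetic data, and real hazy images are then used to fine-tune it in an unsupervised manner. PSD is easy to use and can take most deep dehazing models as its backbone.
  2. Since no clean real images are available for supervision, the authors use several popular, well-grounded physical priors to guide fine-tuning. They combine them into a prior loss committee that serves as task-specific proxy guidance; this committee is the core of PSD.
  3. The authors report state-of-the-art (SOTA) performance.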
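
Contribution 1 says PSD can take most deep dehazing models as its backbone. As a hedged sketch of what "turning a backbone into a physics-based network" can look like, the hypothetical `PhysicsHead` below wraps any feature extractor and adds three outputs: the clean image J, the transmission map t, and the atmospheric light A. The head designs (a residual conv for J, sigmoid-bounded t and A, A estimated from the globally pooled input) are illustrative assumptions, not the authors' code; the `FFANet` in Step 3 plays the same role with its own `conv_J_*`, `conv_T_*` and `ANet` layers.

```python
import torch
import torch.nn as nn


class PhysicsHead(nn.Module):
    """Hypothetical wrapper (not from the PSD code): backbone features -> J, t, A."""

    def __init__(self, backbone: nn.Module, feat_ch: int):
        super().__init__()
        self.backbone = backbone  # assumed to map (B, 3, H, W) -> (B, feat_ch, H, W)
        self.head_j = nn.Conv2d(feat_ch, 3, 3, padding=1)  # clean-image residual
        self.head_t = nn.Conv2d(feat_ch, 1, 3, padding=1)  # transmission map
        # atmospheric light: one RGB value per image, from the globally pooled input
        self.head_a = nn.Sequential(nn.AdaptiveAvgPool2d(1), nn.Conv2d(3, 3, 1))

    def forward(self, hazy: torch.Tensor):
        feat = self.backbone(hazy)
        j = self.head_j(feat) + hazy          # J predicted as a residual on the input
        t = torch.sigmoid(self.head_t(feat))  # t kept in (0, 1), shape (B, 1, H, W)
        a = torch.sigmoid(self.head_a(hazy))  # A kept in (0, 1), shape (B, 3, 1, 1)
        return j, t, a


# usage: wrap a toy one-layer "backbone"
head = PhysicsHead(nn.Conv2d(3, 8, 3, padding=1), feat_ch=8)
j, t, a = head(torch.rand(2, 3, 16, 16))
```

Keeping A as a `(B, 3, 1, 1)` tensor lets it broadcast against the full-resolution J and t when the outputs are combined.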

Step 2: PSD Network Architecture

First, an overview of the two stages of the framework:

**Pre-training**

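  1. First, one of the current best-performing dehazing networks is adopted as the backbone (FFA-Net, implemented as `FFANet` in the Step 3 code).
  2. The backbone is then modified into a physics-based network that, from a single hazy image, simultaneously predicts the clean image J, the transmission map t, and the atmospheric light A. To optimize these three components jointly, the authors add a reconstruction loss that forces the network's outputs to obey the physical scattering model.
  3. In this stage only labeled synthetic data are used for training, which yields a model pre-trained on the synthetic domain.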
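
The reconstruction loss ties J, t and A together through the atmospheric scattering model I = J * t + A * (1 - t), the same expression `FFANet.forward` uses to compute `out_I`. A minimal sketch of a pre-training objective, assuming L1 distances for both terms and a hypothetical weight `lam_rec`; the exact losses and weights used in the paper may differ, and direct supervision of t and A is not shown:

```python
import torch
import torch.nn.functional as F


def scattering_model(j, t, a):
    """Atmospheric scattering model: I = J * t + A * (1 - t)."""
    return j * t + a * (1.0 - t)


def pretrain_loss(j, t, a, hazy, clean, lam_rec=1.0):
    """Hypothetical pre-training loss on one synthetic (hazy, clean) pair.

    j: (B, 3, H, W) predicted clean image, t: (B, 1, H, W) transmission,
    a: (B, 3, 1, 1) atmospheric light, hazy / clean: (B, 3, H, W).
    """
    loss_j = F.l1_loss(j, clean)                           # supervised term on J
    loss_rec = F.l1_loss(scattering_model(j, t, a), hazy)  # outputs must re-synthesize the input
    return loss_j + lam_rec * loss_rec
```

If the predicted J, t and A are exactly the ones that generated a synthetic hazy image, both terms vanish; the reconstruction term is what stops the network from producing a good-looking J with a t and A that are physically inconsistent with the input.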

**Fine-tuning**

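  1. The authors use **unlabeled real data** to generalize the pre-trained model from the synthetic domain to the real domain. Motivated by the strong physical background of dehazing, they argue that a high-quality haze-free image should **follow certain statistical rules** that can be derived from image priors. Moreover, **the physical knowledge provided by any single prior is not always reliable**, so the goal is to find a combination of priors that complement one another.
  2. Based on this, the authors design a prior loss committee that serves as task-specific proxy guidance for **training on unlabeled real data**.
  3. In addition, the authors apply learning without forgetting (LwF): the training data of the original task (the synthetic hazy images) are passed through the network together with the real hazy images, forcing the model to retain what it learned on the synthetic domain.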
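
A hedged sketch of one fine-tuning objective. The priors in the committee are not named here, so two common dehazing priors are assumed for illustration: the dark channel of a haze-free image should be close to 0, and its bright channel close to 1. The LwF term is also an assumption: the synthetic batch is matched against the outputs of a frozen copy of the pre-trained model. `model` and `frozen` are assumed to map a hazy batch (values in [0, 1]) directly to the dehazed image J, and all function names are hypothetical.

```python
import torch
import torch.nn.functional as F


def dark_channel(img, patch=15):
    # per-pixel minimum over RGB, then a local minimum filter (min-pool = -maxpool(-x))
    min_rgb = img.min(dim=1, keepdim=True).values
    return -F.max_pool2d(-min_rgb, patch, stride=1, padding=patch // 2)


def bright_channel(img, patch=15):
    # per-pixel maximum over RGB, then a local maximum filter
    max_rgb = img.max(dim=1, keepdim=True).values
    return F.max_pool2d(max_rgb, patch, stride=1, padding=patch // 2)


def prior_committee_loss(j, w_dark=1.0, w_bright=1.0):
    """Illustrative committee: dark channel pushed toward 0, bright channel toward 1."""
    l_dark = dark_channel(j).abs().mean()
    l_bright = (1.0 - bright_channel(j)).abs().mean()
    return w_dark * l_dark + w_bright * l_bright


def finetune_loss(model, frozen, real_hazy, syn_hazy, lam_lwf=1.0):
    """Unsupervised prior loss on real images plus an LwF-style term on synthetic ones."""
    loss_prior = prior_committee_loss(model(real_hazy))
    with torch.no_grad():
        j_old = frozen(syn_hazy)  # outputs of the frozen pre-trained copy
    loss_lwf = F.l1_loss(model(syn_hazy), j_old)
    return loss_prior + lam_lwf * loss_lwf
```

The two priors illustrate why a single prior is not enough: an all-black output satisfies the dark-channel term perfectly but is maximally penalized by the bright-channel term, so only their combination rules it out.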

Step 3: Model code

    import torch
    import torch.nn as nn
    import torch.nn.functional as F


    class BlockUNet1(nn.Module):
        """U-Net block: activation -> stride-2 (de)conv -> optional InstanceNorm -> optional dropout."""

        def __init__(self, in_channels, out_channels, upsample=False, relu=False, drop=False, bn=True):
            super(BlockUNet1, self).__init__()
            self.conv = nn.Conv2d(in_channels, out_channels, 4, 2, 1, bias=False)             # halves H and W
            self.deconv = nn.ConvTranspose2d(in_channels, out_channels, 4, 2, 1, bias=False)  # doubles H and W
            self.dropout = nn.Dropout2d(0.5)
            self.batch = nn.InstanceNorm2d(out_channels)  # named "batch", but this is instance normalization
            self.upsample = upsample
            self.relu = relu
            self.drop = drop
            self.bn = bn

        def forward(self, x):
            # ReLU in the decoder blocks, LeakyReLU(0.2) in the encoder blocks
            y = F.relu(x) if self.relu else F.leaky_relu(x, 0.2)
            y = self.deconv(y) if self.upsample else self.conv(y)
            if self.bn:
                y = self.batch(y)
            if self.drop:
                y = self.dropout(y)
            return y


    class G2(nn.Module):
        """U-Net that estimates the global atmospheric light A (one value per output channel).

        Eight stride-2 stages downsample by 256, so H and W must be multiples of 256.
        At exactly 256x256 the innermost map is 1x1, which nn.InstanceNorm2d rejects.
        """

        def __init__(self, in_channels, out_channels):
            super(G2, self).__init__()
            # encoder
            self.conv = nn.Conv2d(in_channels, 8, 4, 2, 1, bias=False)
            self.layer1 = BlockUNet1(8, 16)
            self.layer2 = BlockUNet1(16, 32)
            self.layer3 = BlockUNet1(32, 64)
            self.layer4 = BlockUNet1(64, 64)
            self.layer5 = BlockUNet1(64, 64)
            self.layer6 = BlockUNet1(64, 64)
            self.layer7 = BlockUNet1(64, 64)
            # decoder: input channels are doubled by the skip-connection concatenations
            self.dlayer7 = BlockUNet1(64, 64, True, True, True, False)
            self.dlayer6 = BlockUNet1(128, 64, True, True, True)
            self.dlayer5 = BlockUNet1(128, 64, True, True, True)
            self.dlayer4 = BlockUNet1(128, 64, True, True)
            self.dlayer3 = BlockUNet1(128, 32, True, True)
            self.dlayer2 = BlockUNet1(64, 16, True, True)
            self.dlayer1 = BlockUNet1(32, 8, True, True)
            self.relu = nn.ReLU()
            self.dconv = nn.ConvTranspose2d(16, out_channels, 4, 2, 1, bias=False)
            self.lrelu = nn.LeakyReLU(0.2)

        def forward(self, x):
            # encoder: y1 ... y8 at 1/2 ... 1/256 of the input resolution
            y1 = self.conv(x)
            y2 = self.layer1(y1)
            y3 = self.layer2(y2)
            y4 = self.layer3(y3)
            y5 = self.layer4(y4)
            y6 = self.layer5(y5)
            y7 = self.layer6(y6)
            y8 = self.layer7(y7)
            # decoder with U-Net skip connections (channel concatenation)
            dy8 = self.dlayer7(y8)
            concat7 = torch.cat([dy8, y7], 1)
            dy7 = self.dlayer6(concat7)
            concat6 = torch.cat([dy7, y6], 1)
            dy6 = self.dlayer5(concat6)
            concat5 = torch.cat([dy6, y5], 1)
            dy5 = self.dlayer4(concat5)
            concat4 = torch.cat([dy5, y4], 1)
            dy4 = self.dlayer3(concat4)
            concat3 = torch.cat([dy4, y3], 1)
            dy3 = self.dlayer2(concat3)
            concat2 = torch.cat([dy3, y2], 1)
            dy2 = self.dlayer1(concat2)
            concat1 = torch.cat([dy2, y1], 1)
            out = self.relu(concat1)
            out = self.dconv(out)
            out = self.lrelu(out)
            # global average pooling -> (B, out_channels, 1, 1)
            return F.avg_pool2d(out, (out.shape[2], out.shape[3]))


    def default_conv(in_channels, out_channels, kernel_size, bias=True):
        # "same" padding for odd kernel sizes
        return nn.Conv2d(in_channels, out_channels, kernel_size, padding=(kernel_size // 2), bias=bias)


    class PALayer(nn.Module):
        """Pixel attention: one weight in (0, 1) per pixel, shared across channels."""

        def __init__(self, channel):
            super(PALayer, self).__init__()
            self.pa = nn.Sequential(
                nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 8, 1, 1, padding=0, bias=True),
                nn.Sigmoid()
            )

        def forward(self, x):
            y = self.pa(x)
            return x * y


    class CALayer(nn.Module):
        """Channel attention: one weight in (0, 1) per channel, from globally pooled features."""

        def __init__(self, channel):
            super(CALayer, self).__init__()
            self.avg_pool = nn.AdaptiveAvgPool2d(1)
            self.ca = nn.Sequential(
                nn.Conv2d(channel, channel // 8, 1, padding=0, bias=True),
                nn.ReLU(inplace=True),
                nn.Conv2d(channel // 8, channel, 1, padding=0, bias=True),
                nn.Sigmoid()
            )

        def forward(self, x):
            y = self.avg_pool(x)
            y = self.ca(y)
            return x * y


    class Block(nn.Module):
        """FFA basic block: conv + local residual, conv, channel attention, pixel attention, residual."""

        def __init__(self, conv, dim, kernel_size):
            super(Block, self).__init__()
            self.conv1 = conv(dim, dim, kernel_size, bias=True)
            self.act1 = nn.ReLU(inplace=True)
            self.conv2 = conv(dim, dim, kernel_size, bias=True)
            self.calayer = CALayer(dim)
            self.palayer = PALayer(dim)

        def forward(self, x):
            res = self.act1(self.conv1(x))
            res = res + x
            res = self.conv2(res)
            res = self.calayer(res)
            res = self.palayer(res)
            res += x
            return res


    class Group(nn.Module):
        """`blocks` Blocks followed by a conv, wrapped in a long residual connection."""

        def __init__(self, conv, dim, kernel_size, blocks):
            super(Group, self).__init__()
            modules = [Block(conv, dim, kernel_size) for _ in range(blocks)]
            modules.append(conv(dim, dim, kernel_size))
            self.gp = nn.Sequential(*modules)

        def forward(self, x):
            res = self.gp(x)
            res += x
            return res


    class FFANet(nn.Module):
        """FFA-Net backbone modified to predict J, t and A (physics-based PSD variant)."""

        def __init__(self, gps, blocks, conv=default_conv):
            super(FFANet, self).__init__()
            self.gps = gps
            self.dim = 64
            kernel_size = 3
            pre_process = [conv(3, self.dim, kernel_size)]
            assert self.gps == 3  # forward() hard-codes three groups
            self.g1 = Group(conv, self.dim, kernel_size, blocks=blocks)
            self.g2 = Group(conv, self.dim, kernel_size, blocks=blocks)
            self.g3 = Group(conv, self.dim, kernel_size, blocks=blocks)
            # channel attention over the concatenated group outputs (feature fusion)
            self.ca = nn.Sequential(*[
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(self.dim * self.gps, self.dim // 16, 1, padding=0),
                nn.ReLU(inplace=True),
                nn.Conv2d(self.dim // 16, self.dim * self.gps, 1, padding=0, bias=True),
                nn.Sigmoid()
            ])
            self.palayer = PALayer(self.dim)
            # head for the clean image J
            self.conv_J_1 = nn.Conv2d(64, 64, 3, 1, 1, bias=False)
            self.conv_J_2 = nn.Conv2d(64, 3, 3, 1, 1, bias=False)
            # head for the transmission map t
            self.conv_T_1 = nn.Conv2d(64, 16, 3, 1, 1, bias=False)
            self.conv_T_2 = nn.Conv2d(16, 1, 3, 1, 1, bias=False)
            post_process = [
                conv(self.dim, self.dim, kernel_size),
                conv(self.dim, 3, kernel_size)]
            self.pre = nn.Sequential(*pre_process)
            self.post = nn.Sequential(*post_process)  # not used in forward()
            # sub-network for the atmospheric light A
            self.ANet = G2(3, 3)

        def forward(self, x1, x2=0, Val=False):
            x = self.pre(x1)
            res1 = self.g1(x)
            res2 = self.g2(res1)
            res3 = self.g3(res2)
            # fuse the three group outputs with learned per-channel weights
            w = self.ca(torch.cat([res1, res2, res3], dim=1))
            w = w.view(-1, self.gps, self.dim)[:, :, :, None, None]
            out = w[:, 0, ::] * res1 + w[:, 1, ::] * res2 + w[:, 2, ::] * res3
            out = self.palayer(out)
            # J is predicted as a residual on the hazy input
            out_J = self.conv_J_1(out)
            out_J = self.conv_J_2(out_J)
            out_J = out_J + x1
            out_T = self.conv_T_1(out)
            out_T = self.conv_T_2(out_T)
            # A is estimated from x1 by default, or from x2 when Val=True
            if not Val:
                out_A = self.ANet(x1)
            else:
                out_A = self.ANet(x2)
            # atmospheric scattering model: I = J * t + A * (1 - t)
            out_I = out_T * out_J + (1 - out_T) * out_A
            # x = self.post(out)
            return out, out_J, out_T, out_A, out_I


    if __name__ == "__main__":
        net = FFANet(gps=3, blocks=19)
        print(net)

Step 4: Run
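
Because the `ANet` branch (`G2`) halves the resolution eight times, the image fed to `FFANet` must have a height and width that are multiples of 256; the other branches use stride-1 convolutions and accept any size. A minimal inference sketch, assuming a model with the `FFANet.forward` signature above and images scaled to [0, 1]: the hypothetical helpers `pad_to_multiple` and `dehaze` replicate-pad the hazy image, run the model, and crop the dehazed output `out_J` back to the original size. The `min_size=512` default is an assumption chosen to avoid the 1x1 bottleneck that a 256x256 input would produce.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, multiple: int = 256, min_size: int = 512):
    """Replicate-pad bottom/right so H and W are multiples of `multiple` and >= min_size."""
    h, w = x.shape[-2:]
    target_h = max(min_size, ((h + multiple - 1) // multiple) * multiple)
    target_w = max(min_size, ((w + multiple - 1) // multiple) * multiple)
    # F.pad order for the last two dims: (left, right, top, bottom)
    padded = F.pad(x, (0, target_w - w, 0, target_h - h), mode="replicate")
    return padded, (h, w)


@torch.no_grad()
def dehaze(model: nn.Module, hazy: torch.Tensor) -> torch.Tensor:
    """Run an FFANet-style model and return the dehazed image J at the input size."""
    model.eval()
    padded, (h, w) = pad_to_multiple(hazy)
    # forward(x1) returns (features, J, t, A, I); only J is needed here
    _, out_j, _, _, _ = model(padded)
    return out_j[..., :h, :w].clamp(0.0, 1.0)
```

With a trained network this would be called as `dehaze(net, hazy)` for a `(1, 3, H, W)` tensor; loading the weights is not shown because the checkpoint format is not described here.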

Step 5: Contents of the full project

Code download link: Source Code for an AI Image Dehazing System Based on PSD and Deep Neural Networks


If you have questions, send a private message or leave a comment; every question will be answered.


This article is reposted from: https://blog.csdn.net/m0_59023219/article/details/138765064
Copyright belongs to the original author, AI街潜水的八角. In case of infringement, please contact us for removal.
