0


pytorch 自编码器实现图像的降噪

自编码器

自动编码器是一种无监督的深度学习算法,它学习输入数据的编码表示,然后重新构造与输出相同的输入。它由编码器和解码器两个网络组成。编码器将高维输入压缩成低维潜在(也称为潜在代码或编码空间) ,以从中提取最相关的信息,而解码器则解压缩编码数据并重新创建原始输入。

自编码器的输入和输出应该尽可能的相似。

通过输入含有噪声的图像,编码器在编码的过程中会存在信息丢失,将输入和输出最相似的特征保留下来,通过解码器得到最后的输出。在这个转换的过程中实现了图像的去噪。

自编码器主要的用途其实是用于降维,将高维的数据编码为一组向量,解码器通过解码得到输出。

数据集导入可视化

  1. import torchvision
  2. import matplotlib.pyplot as plt
  3. from torch.utils.data import DataLoader
  4. import numpy as np
  5. import random
  6. import PIL.Image as Image
  7. import torchvision.transforms as transforms
  8. class AddPepperNoise(object):
  9. """增加椒盐噪声
  10. Args:
  11. snr (float): Signal Noise Rate
  12. p (float): 概率值,依概率执行该操作
  13. """
  14. def __init__(self, snr, p=0.9):
  15. assert isinstance(snr, float) and (isinstance(p, float))
  16. self.snr = snr
  17. self.p = p
  18. def __call__(self, img):
  19. """
  20. Args:
  21. img (PIL Image): PIL Image
  22. Returns:
  23. PIL Image: PIL image.
  24. """
  25. if random.uniform(0, 1) < self.p:
  26. img_ = np.array(img).copy()
  27. h, w = img_.shape
  28. signal_pct = self.snr
  29. noise_pct = (1 - self.snr)
  30. mask = np.random.choice((0, 1, 2), size=(h, w), p=[signal_pct, noise_pct/2., noise_pct/2.])
  31. img_[mask == 1] = 255 # 盐噪声
  32. img_[mask == 2] = 0 # 椒噪声
  33. return Image.fromarray(img_.astype('uint8'))
  34. else:
  35. return img
  36. class Gaussian_noise(object):
  37. """增加高斯噪声
  38. 此函数用将产生的高斯噪声加到图片上
  39. 传入:
  40. img : 原图
  41. mean : 均值
  42. sigma : 标准差
  43. 返回:
  44. gaussian_out : 噪声处理后的图片
  45. """
  46. def __init__(self, mean, sigma):
  47. self.mean = mean
  48. self.sigma = sigma
  49. def __call__(self, img):
  50. """
  51. Args:
  52. img (PIL Image): PIL Image
  53. Returns:
  54. PIL Image: PIL image.
  55. """
  56. # 将图片灰度标准化
  57. img_ = np.array(img).copy()
  58. img_ = img_ / 255.0
  59. # 产生高斯 noise
  60. noise = np.random.normal(self.mean, self.sigma, img_.shape)
  61. # 将噪声和图片叠加
  62. gaussian_out = img_ + noise
  63. # 将超过 1 的置 1,低于 0 的置 0
  64. gaussian_out = np.clip(gaussian_out, 0, 1)
  65. # 将图片灰度范围的恢复为 0-255
  66. gaussian_out = np.uint8(gaussian_out*255)
  67. # 将噪声范围搞为 0-255
  68. # noise = np.uint8(noise*255)
  69. return Image.fromarray(gaussian_out)
  70. train_datasets = torchvision.datasets.MNIST('./', train=True, download=True)
  71. test_datasets = torchvision.datasets.MNIST('./', train=False, download=True)
  72. print('训练集的数量', len(train_datasets))
  73. print('测试集的数量', len(test_datasets))
  74. train_loader = DataLoader(train_datasets, batch_size=100, shuffle=True)
  75. test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)
  76. transform=transforms.Compose([
  77. transforms.ToPILImage(),
  78. Gaussian_noise(0,0.1),
  79. AddPepperNoise(0.9)
  80. # transforms.ToTensor()
  81. ])
  82. print('训练集可视化')
  83. fig = plt.figure()
  84. for i in range(12):
  85. plt.subplot(3, 4, i + 1)
  86. img = train_datasets.train_data[i]
  87. label = train_datasets.train_labels[i]
  88. # noise = np.random.normal(0.1, 0.1, img.shape)
  89. # img=transform(img)
  90. plt.imshow(img, cmap='gray')
  91. plt.title(label)
  92. plt.xticks([])
  93. plt.yticks([])
  94. plt.show()

噪声图像

原始图像

模型的搭建

  1. import torch
  2. from torch import nn
  3. class AE(nn.Module):
  4. def __init__(self):
  5. super(AE, self).__init__()
  6. # [b, 784] => [b, 20]
  7. self.encoder = nn.Sequential(
  8. nn.Linear(784, 256),
  9. nn.ReLU(),
  10. nn.Linear(256, 64),
  11. nn.ReLU(),
  12. nn.Linear(64, 20),
  13. nn.ReLU()
  14. )
  15. # [b, 20] => [b, 784]
  16. self.decoder = nn.Sequential(
  17. nn.Linear(20, 64),
  18. nn.ReLU(),
  19. nn.Linear(64, 256),
  20. nn.ReLU(),
  21. nn.Linear(256, 784),
  22. nn.Sigmoid()
  23. )
  24. def forward(self, x):
  25. """
  26. :param x: [b, 1, 28, 28]
  27. :return:
  28. """
  29. batchsz = x.size(0)
  30. # flatten(打平)
  31. x = x.view(batchsz, 784)
  32. # encoder
  33. x = self.encoder(x)
  34. # decoder
  35. x = self.decoder(x)
  36. # reshape
  37. x = x.view(batchsz, 1, 28, 28)
  38. return x
  39. if __name__=='__main__':
  40. model=AE()
  41. input=torch.randn(1,28,28)
  42. input=input.view(1,-1)
  43. print('输入的维度',input.shape)
  44. encoder_out=model.encoder(input)
  45. print('编码器的输出',encoder_out.shape)
  46. out=model.decoder(encoder_out)
  47. print('解码器的输出',out.shape)

模型的训练

导入训练集训练的时候一定要将使用transforms将所有图像转换为tensor格式,这里的方法不同于tensorflow导入MNIST方法,如果不加transforms则图像的格式为列表类型,下面在训练的时候会报错。

在训练过程中添加噪声。分别添加了高斯噪声和椒盐噪声

  1. import torchvision
  2. from torch.utils.data import DataLoader
  3. import numpy as np
  4. import random,os
  5. import PIL.Image as Image
  6. import torchvision.transforms as transforms
  7. from torch import nn,optim
  8. import torch
  9. from models import AE
  10. from tqdm import tqdm
  11. train_datasets = torchvision.datasets.MNIST('./', train=True, download=True,transform=transforms.ToTensor())
  12. test_datasets = torchvision.datasets.MNIST('./', train=False, download=True,transform=transforms.ToTensor())
  13. print('训练集的数量', len(train_datasets))
  14. print('测试集的数量', len(test_datasets))
  15. train_loader = DataLoader(train_datasets, batch_size=100, shuffle=True)
  16. test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)
  17. #模型,优化器,损失函数
  18. device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  19. model=AE().to(device)
  20. criteon = nn.MSELoss()
  21. optimizer = optim.Adam(model.parameters(), lr=1e-3)
  22. ##导入预训练模型
  23. if os.path.exists('./model.pth') :
  24. # 如果存在已保存的权重,则加载
  25. checkpoint = torch.load('model.pth',map_location=lambda storage,loc:storage)
  26. model.load_state_dict(checkpoint['model_state_dict'])
  27. optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
  28. initepoch = checkpoint['epoch']
  29. loss = checkpoint['loss']
  30. else:
  31. initepoch=0
  32. #开始训练
  33. for epoch in range(initepoch, 50):
  34. with tqdm(total=(len(train_datasets)-len(train_datasets)), ncols=80) as t:
  35. t.set_description('epoch: {}/{}'.format(epoch, 50))
  36. running_loss = 0.0
  37. for i, data in enumerate(train_loader, 0):
  38. # get the inputs
  39. true_input, _ = data
  40. #生成均值为0,方差为0.1的高斯分布
  41. gaussian_noise=torch.normal(mean=0,std=0.1,size=true_input.shape)
  42. image_noise=true_input+gaussian_noise
  43. noise_tensor = torch.rand(size=true_input.shape)
  44. #添加椒盐噪声
  45. image_noise[noise_tensor<0.1]=0 #椒噪声
  46. image_noise[noise_tensor > (1-0.1)] = 1 #盐噪声
  47. #限制像素的范围在0-1之间
  48. image_noise=torch.clamp(image_noise,min=0,max=1)
  49. optimizer.zero_grad()
  50. outputs = model(image_noise)
  51. loss = criteon(outputs, true_input)
  52. loss.backward()
  53. optimizer.step()
  54. running_loss += loss.item()
  55. t.set_postfix(trainloss='{:.6f}'.format(running_loss/len(train_loader)))
  56. t.update(len(true_input))
  57. torch.save({'epoch': epoch,
  58. 'model_state_dict': model.state_dict(),
  59. 'optimizer_state_dict': optimizer.state_dict(),
  60. 'loss': running_loss/len(train_loader)
  61. }, 'model.pth')

模型的测试

在导入模型的时候经常发生上面的错误。模型在导入参数的时候不需要赋值操作。如果保存的方法是torch.load(model,'model.pth'),也就是直接保存模型的所有(包括模型的结构),在导入模型参数的时候可以使用model=torch.load('./model.pth')

  1. import numpy as np
  2. import torchvision
  3. import matplotlib.pyplot as plt
  4. from torch.utils.data import DataLoader
  5. import torch
  6. import torchvision.transforms as transforms
  7. from models import AE
  8. from data import AddPepperNoise,Gaussian_noise
  9. test_datasets = torchvision.datasets.MNIST('./', train=False, download=True)
  10. print('测试集的数量', len(test_datasets))
  11. test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)
  12. transform=transforms.Compose([
  13. transforms.ToPILImage(),
  14. Gaussian_noise(0,0.2),
  15. AddPepperNoise(0.9),
  16. transforms.ToTensor()
  17. ])
  18. model=AE()
  19. hh=torch.load('./model.pth',map_location=lambda storage,loc:storage)
  20. model.load_state_dict(hh['model_state_dict'])
  21. #错误写法
  22. # model=model.load_state_dict(hh['model_state_dict'])
  23. fig = plt.figure()
  24. for i in range(12):
  25. plt.subplot(3, 4, i + 1)
  26. img = test_datasets.train_data[i]
  27. label = test_datasets.test_labels[i]
  28. img_noise=transform(img)
  29. out=model(img_noise)
  30. out=out.squeeze()
  31. out=transforms.ToPILImage()(out)
  32. #原始图像,噪声图像,去噪图像
  33. plt.imshow(np.hstack((np.array(img),np.array(transforms.ToPILImage()(img_noise)),np.array(out))), cmap='gray')
  34. plt.title(label)
  35. plt.xticks([])
  36. plt.yticks([])
  37. plt.show()

生成随机数看看解码器能解码出什么

生成标准正太分布

  1. import matplotlib.pyplot as plt
  2. import torch
  3. import torchvision.transforms as transforms
  4. from models import AE
  5. model=AE()
  6. model.load_state_dict(torch.load('./model.pth',map_location=lambda storage,loc:storage)['model_state_dict'])
  7. fig = plt.figure()
  8. for i in range(12):
  9. plt.subplot(3, 4, i + 1)
  10. input=torch.randn(1,20)
  11. out=model.decoder(input)
  12. out=out.view(28,28)
  13. out=transforms.ToPILImage()(out)
  14. plt.imshow(out, cmap='gray')
  15. plt.xticks([])
  16. plt.yticks([])
  17. plt.show()

生成0-1之间的均匀分布

  1. import matplotlib.pyplot as plt
  2. import torch
  3. import torchvision.transforms as transforms
  4. from models import AE
  5. model=AE()
  6. model.load_state_dict(torch.load('./model.pth',map_location=lambda storage,loc:storage)['model_state_dict'])
  7. fig = plt.figure()
  8. for i in range(12):
  9. plt.subplot(3, 4, i + 1)
  10. input=torch.rand(1,20)
  11. out=model.decoder(input)
  12. out=out.view(28,28)
  13. out=transforms.ToPILImage()(out)
  14. plt.imshow(out, cmap='gray')
  15. plt.xticks([])
  16. plt.yticks([])
  17. plt.show()

可以看到随机生成的数据用解码器解码得到的数据都很乱。接下来,看看编码器编码后的数据服从什么分布。

看看编码器编码的输出服从什么分布

  1. import numpy as np
  2. import torchvision
  3. import matplotlib.pyplot as plt
  4. from torch.utils.data import DataLoader
  5. import torch
  6. import torchvision.transforms as transforms
  7. from models import AE
  8. from data import AddPepperNoise,Gaussian_noise
  9. test_datasets = torchvision.datasets.MNIST('./', train=False, download=True)
  10. print('测试集的数量', len(test_datasets))
  11. test_loader = DataLoader(test_datasets, batch_size=1, shuffle=False)
  12. transform=transforms.Compose([
  13. transforms.ToPILImage(),
  14. Gaussian_noise(0,0.2),
  15. AddPepperNoise(0.9),
  16. transforms.ToTensor()
  17. ])
  18. model=AE()
  19. hh=torch.load('./model.pth',map_location=lambda storage,loc:storage)
  20. model.load_state_dict(hh['model_state_dict'])
  21. #错误写法
  22. # model=model.load_state_dict(hh['model_state_dict'])
  23. fig = plt.figure()
  24. for i in range(1):
  25. plt.subplot(1, 1, i + 1)
  26. img = test_datasets.test_data[i]
  27. label = test_datasets.test_labels[i]
  28. img_noise=transform(img)
  29. img_noise=img_noise.view(1,-1)
  30. out=model.encoder(img_noise)
  31. print('encoder的输出',out)
  32. #正太分布检验
  33. import scipy.stats as stats
  34. print(stats.shapiro(out.detach().numpy()))
  35. plt.imshow(img, cmap='gray')
  36. plt.title(label)
  37. plt.xticks([])
  38. plt.yticks([])
  39. plt.show()
  40. print('均值',torch.mean(out))
  41. print('方差',torch.var(out))

可以看到一张图片7是服从均值为2.5,方差为8.55的正太分布的。

然后生成一些类似的分布看看效果。

  1. import matplotlib.pyplot as plt
  2. import torch
  3. import torchvision.transforms as transforms
  4. from models import AE
  5. model=AE()
  6. model.load_state_dict(torch.load('./model.pth',map_location=lambda storage,loc:storage)['model_state_dict'])
  7. fig = plt.figure()
  8. for i in range(12):
  9. plt.subplot(3, 4, i + 1)
  10. input=torch.normal(mean=2.5928,std=8.5510,size=(1,20))
  11. out=model.decoder(input)
  12. out=out.view(28,28)
  13. out=transforms.ToPILImage()(out)
  14. plt.imshow(out, cmap='gray')
  15. plt.xticks([])
  16. plt.yticks([])
  17. plt.show()

其实效果挺差的,可能是因为一张图片的分布并不能代表所有吧。


本文转载自: https://blog.csdn.net/qq_40107571/article/details/128357520
版权归原作者 南妮儿 所有, 如有侵权,请联系我们删除。

“pytorch 自编码器实现图像的降噪”的评论:

还没有评论