

Implementing CycleGAN Style Transfer in PyTorch

1. Introduction

  1. pix2pix places high demands on its training samples: it needs paired datasets, which often take a lot of effort to collect. CycleGAN solves exactly this problem by learning a translation between two domains. You only need to prepare two sets of images, one per style, and the GAN learns to convert images from domain X into the style of domain Y (changing only the style, leaving the objects in the original domain-X image untouched).
  2. An intuitive approach would be to let G directly learn the Domain X -> Domain Y and Domain Y -> Domain X mappings, but this leaves G's output far too unconstrained: the generated image can end up completely unrelated to the input, which both defeats the purpose of CycleGAN and makes the output meaningless.
  3. The authors argue the translation should be cycle-consistent. In language translation, for example, translating a passage from Chinese to English and then back from English to Chinese should preserve its meaning, and CycleGAN adopts the same idea. If Ga is the Domain X -> Domain Y generator and Gb the Domain Y -> Domain X generator, then taking an image real_A from domain X, passing it through Ga to get fake_B, and passing fake_B through Gb to get rec_A, rec_A should be highly similar to real_A; the Domain Y -> Domain X direction works the same way.
  4. CycleGAN therefore contains two generators and two discriminators, one pair for Domain X -> Domain Y and one for Domain Y -> Domain X.
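The cycle idea can be made concrete with a toy pair of invertible mappings (these stand-in functions are purely illustrative, not real generators):

```python
# Toy illustration of cycle consistency: pretend domain X holds numbers
# and the "generators" are simple invertible functions.

def Ga(x):        # Domain X -> Domain Y (toy: double the value)
    return 2 * x

def Gb(y):        # Domain Y -> Domain X (toy: halve the value)
    return y / 2

real_A = 3.0
fake_B = Ga(real_A)   # translated into domain Y
rec_A = Gb(fake_B)    # translated back into domain X

# cycle consistency: the reconstruction should match the original input
cycle_loss = abs(rec_A - real_A)  # L1 distance; 0.0 for this perfect cycle
```

In the real model Ga and Gb are neural networks, so rec_A is only approximately equal to real_A, and the L1 distance between them is penalized as the cycle loss.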

2. Dataset

  1. I use the monet2photo dataset (Monet paintings -> real landscape photos); some samples are shown below.
  2. **Domain X (monet):**

  3. **Domain Y (photo):**

3. Network Architecture

  1. The structure of the generator G is shown below; the discriminator D is the same as in pix2pix (see the pix2pix post for its structure).
  2. ![](https://img-blog.csdnimg.cn/a29570f091c54efd86145e8adb11d566.png?x-oss-process=image/watermark,type_d3F5LXplbmhlaQ,shadow_50,text_Q1NETiBA6L-b6Zi244GubWt5,size_20,color_FFFFFF,t_70,g_se,x_16)

4. Code

(1) net

  1. The weight initialization differs from the official implementation.
```python
import torch.nn as nn
from torchsummary import summary
from collections import OrderedDict


# residual block
class Resnet_block(nn.Module):
    def __init__(self, in_channels):
        super(Resnet_block, self).__init__()
        block = []
        for i in range(2):
            block += [nn.ReflectionPad2d(1),
                      nn.Conv2d(in_channels, in_channels, 3, 1, 0),
                      nn.InstanceNorm2d(in_channels),
                      # ReLU only after the first conv, as in the standard ResNet block
                      nn.ReLU(True) if i == 0 else nn.Identity()]
        self.block = nn.Sequential(*block)

    def forward(self, x):
        out = x + self.block(x)
        return out


class Cycle_Gan_G(nn.Module):
    def __init__(self):
        super(Cycle_Gan_G, self).__init__()
        net_dic = OrderedDict()
        # three convolutional layers
        net_dic.update({'first layer': nn.Sequential(
            nn.ReflectionPad2d(3),   # [3,256,256] -> [3,262,262]
            nn.Conv2d(3, 64, 7, 1),  # [3,262,262] -> [64,256,256]
            nn.InstanceNorm2d(64),
            nn.ReLU(True)
        )})
        net_dic.update({'second_conv': nn.Sequential(
            nn.Conv2d(64, 128, 3, 2, 1),  # [128,128,128]
            nn.InstanceNorm2d(128),
            nn.ReLU(True)
        )})
        net_dic.update({'three_conv': nn.Sequential(
            nn.Conv2d(128, 256, 3, 2, 1),  # [256,64,64]
            nn.InstanceNorm2d(256),
            nn.ReLU(True)
        )})
        # 6 ResNet blocks (the official model uses 9 for 256x256 inputs)
        for i in range(6):
            net_dic.update({'Resnet_block{}'.format(i + 1): Resnet_block(256)})
        # up-sampling
        net_dic.update({'up_sample1': nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(128),  # [128,128,128]
            nn.ReLU(True)
        )})
        net_dic.update({'up_sample2': nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=2, padding=1, output_padding=1),
            nn.InstanceNorm2d(64),  # [64,256,256]
            nn.ReLU(True)
        )})
        net_dic.update({'last_layer': nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(64, 3, 7, 1),
            nn.Tanh()
        )})
        self.net_G = nn.Sequential(net_dic)
        self.init_weight()

    def init_weight(self):
        for w in self.modules():
            if isinstance(w, nn.Conv2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_out')
                if w.bias is not None:
                    nn.init.zeros_(w.bias)
            elif isinstance(w, nn.ConvTranspose2d):
                nn.init.kaiming_normal_(w.weight, mode='fan_in')
            elif isinstance(w, nn.BatchNorm2d):
                # never triggered here (this net uses non-affine InstanceNorm2d);
                # kept for reuse with BatchNorm-based variants
                nn.init.ones_(w.weight)
                nn.init.zeros_(w.bias)

    def forward(self, x):
        out = self.net_G(x)
        return out


class Cycle_Gan_D(nn.Module):
    def __init__(self):
        super(Cycle_Gan_D, self).__init__()

        # basic conv / norm / LeakyReLU block; the first (RGB input) layer skips the norm
        def base_Conv_bn_lkrl(in_channels, out_channels, stride):
            if in_channels == 3:
                bn = nn.Identity
            else:
                bn = nn.InstanceNorm2d
            return nn.Sequential(
                nn.Conv2d(in_channels, out_channels, 4, stride, 1),
                bn(out_channels),
                nn.LeakyReLU(0.2, True)
            )

        D_dic = OrderedDict()
        in_channels = 3
        out_channels = 64
        for i in range(4):
            if i < 3:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 2)})
            else:
                D_dic.update({'layer_{}'.format(i + 1): base_Conv_bn_lkrl(in_channels, out_channels, 1)})
            in_channels = out_channels
            out_channels *= 2
        D_dic.update({'last_layer': nn.Conv2d(512, 1, 4, 1, 1)})  # [batch,1,30,30]
        self.D_model = nn.Sequential(D_dic)

    def forward(self, x):
        return self.D_model(x)


if __name__ == '__main__':
    # G = Cycle_Gan_G().to('cuda')
    # summary(G, (3, 256, 256))
    D = Cycle_Gan_D().to('cuda')
    summary(D, (3, 256, 256))
```
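As a quick sanity check (not part of the original post), the `[batch,1,30,30]` shape noted in the discriminator can be verified by chaining the standard convolution output-size formula over its five 4x4 convolutions; with this configuration each of the 30x30 outputs scores a 70x70 patch of the input, the classic PatchGAN receptive field:

```python
# out = floor((n + 2*p - k) / s) + 1 for a conv with kernel k, stride s, padding p
def conv_out(n, k=4, s=2, p=1):
    return (n + 2 * p - k) // s + 1

n = 256
# layer_1..layer_3 use stride 2; layer_4 and the final 512->1 conv use stride 1
for stride in [2, 2, 2, 1, 1]:
    n = conv_out(n, s=stride)
print(n)  # -> 30
```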

(2) train

  1. A few details about training. To reduce model oscillation and improve stability, the paper keeps a buffer of images produced by G and updates the discriminators using previously generated images rather than only the latest ones. G has three kinds of loss (six terms over the two directions): GAN_loss, Cycle_loss, and id_loss. GAN_loss is the usual GAN loss that pushes the generated image to look real; Cycle_loss is the L1 loss between the reconstructed image and the original; id_loss keeps G from arbitrarily changing an image's color tones (even if the discriminator would accept another tone as belonging to domain Y, we want to change only the style and nothing else, hence this loss). The discriminator D is still a PatchGAN, and the training procedure is similar to pix2pix.
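How the six generator terms are weighted in the `train_G` code below (cycle terms scaled by 10, identity terms by 10 * 0.5 = 5) can be summarized in a small sketch; `total_G_loss` is a name introduced here for illustration:

```python
# Combine the six generator loss terms for both directions.
# gan_*: adversarial terms, cyc_*: cycle terms, id_*: identity terms.
def total_G_loss(gan_a, gan_b, cyc_a, cyc_b, id_a, id_b,
                 lambda_cycle=10.0, lambda_id=5.0):
    return (gan_a + gan_b
            + lambda_cycle * (cyc_a + cyc_b)
            + lambda_id * (id_a + id_b))

# with unit losses the cycle terms dominate the objective:
print(total_G_loss(1, 1, 1, 1, 1, 1))  # -> 32.0
```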
```python
import itertools
from image_pool import ImagePool
from torch.utils.tensorboard import SummaryWriter
from cyclegan import Cycle_Gan_G, Cycle_Gan_D
import argparse
from mydatasets import CreateDatasets
import os
from torch.utils.data.dataloader import DataLoader
import torch
import torch.optim as optim
import torch.nn as nn
from utils import train_one_epoch, val


def train(opt):
    batch = opt.batch
    data_path = opt.dataPath
    print_every = opt.every
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    epochs = opt.epoch
    img_size = opt.imgsize
    if not os.path.exists(opt.savePath):
        os.mkdir(opt.savePath)

    # load the datasets
    train_datasets = CreateDatasets(data_path, img_size, mode='train')
    val_datasets = CreateDatasets(data_path, img_size, mode='test')
    train_loader = DataLoader(dataset=train_datasets, batch_size=batch, shuffle=True,
                              num_workers=opt.numworker, drop_last=True)
    val_loader = DataLoader(dataset=val_datasets, batch_size=batch, shuffle=True,
                            num_workers=opt.numworker, drop_last=True)

    # instantiate the networks
    Cycle_G_A = Cycle_Gan_G().to(device)  # Ga: X -> Y
    Cycle_D_A = Cycle_Gan_D().to(device)  # Da: judges domain Y
    Cycle_G_B = Cycle_Gan_G().to(device)  # Gb: Y -> X
    Cycle_D_B = Cycle_Gan_D().to(device)  # Db: judges domain X

    # optimizers and loss functions (the generators share one optimizer, as do the discriminators)
    optim_G = optim.Adam(itertools.chain(Cycle_G_A.parameters(), Cycle_G_B.parameters()),
                         lr=0.0002, betas=(0.5, 0.999))
    optim_D = optim.Adam(itertools.chain(Cycle_D_A.parameters(), Cycle_D_B.parameters()),
                         lr=0.0002, betas=(0.5, 0.999))
    loss = nn.MSELoss()  # LSGAN adversarial loss
    l1_loss = nn.L1Loss()
    start_epoch = 0
    A_fake_pool = ImagePool(50)  # history buffers of generated images
    B_fake_pool = ImagePool(50)

    # load pretrained weights
    if opt.weight != '':
        ckpt = torch.load(opt.weight)
        Cycle_G_A.load_state_dict(ckpt['Ga_model'], strict=False)
        Cycle_G_B.load_state_dict(ckpt['Gb_model'], strict=False)
        Cycle_D_A.load_state_dict(ckpt['Da_model'], strict=False)
        Cycle_D_B.load_state_dict(ckpt['Db_model'], strict=False)
        start_epoch = ckpt['epoch'] + 1

    writer = SummaryWriter('train_logs')

    # start training
    for epoch in range(start_epoch, epochs):
        loss_mG, loss_mD = train_one_epoch(Ga=Cycle_G_A, Da=Cycle_D_A, Gb=Cycle_G_B, Db=Cycle_D_B,
                                           train_loader=train_loader,
                                           optim_G=optim_G, optim_D=optim_D, writer=writer, loss=loss,
                                           device=device, plot_every=print_every, epoch=epoch,
                                           l1_loss=l1_loss,
                                           A_fake_pool=A_fake_pool, B_fake_pool=B_fake_pool)
        writer.add_scalars(main_tag='train_loss', tag_scalar_dict={
            'loss_G': loss_mG,
            'loss_D': loss_mD
        }, global_step=epoch)

        # save the models
        torch.save({
            'Ga_model': Cycle_G_A.state_dict(),
            'Gb_model': Cycle_G_B.state_dict(),
            'Da_model': Cycle_D_A.state_dict(),
            'Db_model': Cycle_D_B.state_dict(),
            'epoch': epoch
        }, './weights/cycle_monent2photo.pth')

        # validation
        val(Ga=Cycle_G_A, Da=Cycle_D_A, Gb=Cycle_G_B, Db=Cycle_D_B, val_loader=val_loader,
            loss=loss, l1_loss=l1_loss, device=device, epoch=epoch)


def cfg():
    parse = argparse.ArgumentParser()
    parse.add_argument('--batch', type=int, default=1)
    parse.add_argument('--epoch', type=int, default=100)
    parse.add_argument('--imgsize', type=int, default=256)
    parse.add_argument('--dataPath', type=str, default='../monet2photo', help='data root path')
    parse.add_argument('--weight', type=str, default='', help='load pre train weight')
    parse.add_argument('--savePath', type=str, default='./weights', help='weight save path')
    parse.add_argument('--numworker', type=int, default=4)
    parse.add_argument('--every', type=int, default=20, help='plot train result every * iters')
    opt = parse.parse_args()
    return opt


if __name__ == '__main__':
    opt = cfg()
    print(opt)
    train(opt)
```
The training and validation helpers (`utils.py`, imported above):

```python
import torchvision
from tqdm import tqdm
import torch
import os


def train_one_epoch(Ga, Da, Gb, Db, train_loader, optim_G, optim_D, writer, loss, device, plot_every,
                    epoch, l1_loss, A_fake_pool, B_fake_pool):
    pd = tqdm(train_loader)
    loss_D, loss_G = 0, 0
    step = 0
    Ga.train()
    Da.train()
    Gb.train()
    Db.train()
    for idx, data in enumerate(pd):
        A_real = data[0].to(device)
        B_real = data[1].to(device)

        # forward pass
        B_fake = Ga(A_real)  # fake B produced by Ga
        A_rec = Gb(B_fake)   # A reconstructed by Gb
        A_fake = Gb(B_real)  # fake A produced by Gb
        B_rec = Ga(A_fake)   # B reconstructed by Ga

        # train G => six loss terms in total
        set_required_grad([Da, Db], requires_grad=False)  # freeze D
        optim_G.zero_grad()
        ls_G = train_G(Da=Da, Db=Db, B_fake=B_fake, loss=loss, A_fake=A_fake, l1_loss=l1_loss,
                       A_rec=A_rec, A_real=A_real, B_rec=B_rec, B_real=B_real, Ga=Ga, Gb=Gb)
        ls_G.backward()
        optim_G.step()

        # train D on images drawn from the history buffers
        set_required_grad([Da, Db], requires_grad=True)
        optim_D.zero_grad()
        A_fake_p = A_fake_pool.query(A_fake)
        B_fake_p = B_fake_pool.query(B_fake)
        ls_D = train_D(Da=Da, Db=Db, B_fake=B_fake_p, B_real=B_real, loss=loss,
                       A_fake=A_fake_p, A_real=A_real)
        ls_D.backward()
        optim_D.step()

        loss_D += ls_D.item()  # accumulate as python floats
        loss_G += ls_G.item()
        pd.desc = 'train_{} G_loss: {} D_loss: {}'.format(epoch, ls_G.item(), ls_D.item())

        # log intermediate results (images are de-normalized from [-1,1] back to [0,1])
        if idx % plot_every == 0:
            writer.add_images(tag='epoch{}_Ga'.format(epoch),
                              img_tensor=0.5 * (torch.cat([A_real, B_fake], 0) + 1),
                              global_step=step)
            writer.add_images(tag='epoch{}_Gb'.format(epoch),
                              img_tensor=0.5 * (torch.cat([B_real, A_fake], 0) + 1),
                              global_step=step)
            step += 1
    mean_lsG = loss_G / len(train_loader)
    mean_lsD = loss_D / len(train_loader)
    return mean_lsG, mean_lsD


@torch.no_grad()
def val(Ga, Da, Gb, Db, val_loader, loss, device, l1_loss, epoch):
    pd = tqdm(val_loader)
    loss_D, loss_G = 0, 0
    Ga.eval()
    Da.eval()
    Gb.eval()
    Db.eval()
    all_loss = 10000
    for idx, item in enumerate(pd):
        A_real_img = item[0].to(device)
        B_real_img = item[1].to(device)
        B_fake_img = Ga(A_real_img)
        A_fake_img = Gb(B_real_img)
        A_rec = Gb(B_fake_img)
        B_rec = Ga(A_fake_img)

        # D loss
        ls_D = train_D(Da=Da, Db=Db, B_fake=B_fake_img, B_real=B_real_img, loss=loss,
                       A_fake=A_fake_img, A_real=A_real_img)
        # G loss
        ls_G = train_G(Da=Da, Db=Db, B_fake=B_fake_img, loss=loss, A_fake=A_fake_img, l1_loss=l1_loss,
                       A_rec=A_rec, A_real=A_real_img, B_rec=B_rec, B_real=B_real_img, Ga=Ga, Gb=Gb)
        loss_G += ls_G.item()
        loss_D += ls_D.item()
        pd.desc = 'val_{}: G_loss:{} D_Loss:{}'.format(epoch, ls_G.item(), ls_D.item())

        # keep the batch with the lowest total loss
        all_ls = ls_G + ls_D
        if all_ls < all_loss:
            all_loss = all_ls
            best_image = torch.cat([A_real_img, B_fake_img, B_real_img, A_fake_img], 0)
    result_img = (best_image + 1) * 0.5
    if not os.path.exists('./results'):
        os.mkdir('./results')
    torchvision.utils.save_image(result_img, './results/val_epoch{}_cycle.jpg'.format(epoch))


def set_required_grad(nets, requires_grad=False):
    if not isinstance(nets, list):
        nets = [nets]
    for net in nets:
        if net is not None:
            for params in net.parameters():
                params.requires_grad = requires_grad


def train_G(Da, Db, B_fake, loss, A_fake, l1_loss, A_rec, A_real, B_rec, B_real, Ga, Gb):
    # GAN loss: try to fool both discriminators
    Da_out_fake = Da(B_fake)
    Ga_gan_loss = loss(Da_out_fake, torch.ones_like(Da_out_fake))
    Db_out_fake = Db(A_fake)
    Gb_gan_loss = loss(Db_out_fake, torch.ones_like(Db_out_fake))
    # cycle loss (lambda = 10)
    Cycle_A_loss = l1_loss(A_rec, A_real) * 10
    Cycle_B_loss = l1_loss(B_rec, B_real) * 10
    # identity loss (lambda = 10 * 0.5)
    Ga_id_out = Ga(B_real)
    Gb_id_out = Gb(A_real)
    Ga_id_loss = l1_loss(Ga_id_out, B_real) * 10 * 0.5
    Gb_id_loss = l1_loss(Gb_id_out, A_real) * 10 * 0.5
    # total G loss
    ls_G = Ga_gan_loss + Gb_gan_loss + Cycle_A_loss + Cycle_B_loss + Ga_id_loss + Gb_id_loss
    return ls_G


def train_D(Da, Db, B_fake, B_real, loss, A_fake, A_real):
    # Da loss: real B -> 1, fake B -> 0
    Da_fake_out = Da(B_fake.detach()).squeeze()
    Da_real_out = Da(B_real).squeeze()
    ls_Da1 = loss(Da_fake_out, torch.zeros_like(Da_fake_out))
    ls_Da2 = loss(Da_real_out, torch.ones_like(Da_real_out))
    ls_Da = (ls_Da1 + ls_Da2) * 0.5
    # Db loss: real A -> 1, fake A -> 0
    Db_fake_out = Db(A_fake.detach()).squeeze()
    Db_real_out = Db(A_real).squeeze()
    ls_Db1 = loss(Db_fake_out, torch.zeros_like(Db_fake_out))
    ls_Db2 = loss(Db_real_out, torch.ones_like(Db_real_out))
    ls_Db = (ls_Db1 + ls_Db2) * 0.5
    # total D loss
    ls_D = ls_Da + ls_Db
    return ls_D
```
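The `ImagePool` imported by the training script is not shown in the post; below is a minimal sketch of the 50-image history buffer described above, written for single items rather than batched tensors (the real version would also clone tensors and detach them from the graph):

```python
import random


class ImagePool:
    """History buffer: with probability 0.5 the discriminator sees an
    older generated image instead of the newest one."""

    def __init__(self, pool_size=50):
        self.pool_size = pool_size
        self.images = []

    def query(self, image):
        if self.pool_size == 0:              # buffer disabled
            return image
        if len(self.images) < self.pool_size:
            self.images.append(image)        # fill the buffer first
            return image
        if random.random() < 0.5:
            # swap: return an old image, store the new one in its place
            idx = random.randrange(self.pool_size)
            old = self.images[idx]
            self.images[idx] = image
            return old
        return image                         # otherwise pass the new image through
```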

(3) test

```python
from cyclegan import Cycle_Gan_G
import torch
import torchvision.transforms as transform
import matplotlib.pyplot as plt
import cv2
from PIL import Image


def test(img_path):
    if img_path.endswith('.png'):
        img = cv2.imread(img_path)
        img = img[:, :, ::-1]  # BGR -> RGB
    else:
        img = Image.open(img_path)
    transforms = transform.Compose([
        transform.ToTensor(),
        transform.Resize((256, 256)),
        transform.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
    ])
    img = transforms(img.copy())
    img = img[None].to('cuda')  # [1,3,256,256]

    # instantiate the network (Gb maps photo -> monet)
    Gb = Cycle_Gan_G().to('cuda')
    # load pretrained weights
    ckpt = torch.load('weights/cycle_monent2photo.pth')
    Gb.load_state_dict(ckpt['Gb_model'], strict=False)
    Gb.eval()

    out = Gb(img)[0]
    out = out.permute(1, 2, 0)
    out = (0.5 * (out + 1)).cpu().detach().numpy()  # de-normalize to [0,1]
    plt.figure()
    plt.imshow(out)
    plt.show()


if __name__ == '__main__':
    test('123.jpg')
```

5. Results

(1) Loss

(2) Training visualization

  1. Here I picked out a few training and validation results.
  2. monet -> photo on the training set:

  1. photo -> monet on the training set:

  1. Results on the validation set (left: monet -> photo, right: photo -> monet):

(3) Test results

  1. Below is a photo -> monet result.

6. Full Code

  1. Dataset: Baidu Netdisk, extraction code: s3e3
  2. Code: Baidu Netdisk, extraction code: t0d5

Reposted from: https://blog.csdn.net/qq_57886603/article/details/122136248
Copyright belongs to the original author 进阶のmky. In case of infringement, please contact us for removal.
