0


三、计算机视觉_09GAN对抗学习案例

1、对抗学习基础知识

1.1 定义

对抗学习(Adversarial Learning)是一种机器学习范式,它涉及到两个或多个模型在相互竞争的环境中进行训练,以提高各自的性能,这种学习方式的核心思想是通过对抗过程来激发模型的潜力,使它们在面对对手的挑战时不断进化和改进

1.2 GAN

GAN(Generative Adversarial Network,即:生成对抗网络,复数是GANs)是对抗学习研究中最突出的一项技术,它由两部分组成:一个生成器(G)和一个判别器(D),生成器的任务是学会制造假的数据,判别器的任务是区分真实数据和假数据

1.3 核心思想

对抗学习的模型不重要,思想很重要:

  • 对抗过程:在对抗学习中,通常有两个模型【生成器(Generator)和鉴别器(Discriminator)】在相互竞争,生成器的目标是生成尽可能接近真实数据的假数据,而鉴别器的目标是区分真实数据和生成器生成的假数据
  • 共同成长:对抗学习的特点之一是两个模型在对抗中共同进步,随着生成器生成越来越逼真的数据,鉴别器也必须变得更加聪明以识别这些假数据;反之亦然,随着鉴别器变得更加敏锐,生成器也必须提高其生成质量
  • 非零和博弈:与零和博弈不同,在对抗学习中,生成器和鉴别器并不是一方赢另一方输的关系,它们通过竞争相互促进,最终达到共同提高的目的

2、对抗学习的应用

2.1 应用到图像生成任务

问题:给定一堆原始图像,要求训练一个模型,可以模仿给定的图像,来生成类似的图像

采用对抗学习的方法进行解决问题:

  • 生成器:这个模型负责生成图像,它从随机噪声开始,通过学习原始图像的特征,逐渐生成越来越逼真的图像
  • 鉴别器:这个模型负责判断一个图像是真实的还是由生成器生成的,它通过分析图像的特征来做出判断
  • 训练过程:在训练过程中,生成器和鉴别器交替更新,生成器试图欺骗鉴别器,而鉴别器则试图不被欺骗,这种对抗推动了两个模型的性能提升

2.2 应用到其他任务

  • 模型融合:将对抗学习的思想应用到其他模型设计中,可以提高这些模型的泛化能力和鲁棒性(例如,在半监督学习中,可以利用对抗学习的思想来增强模型对未标记数据的学习能力)
  • 多任务学习:在多任务学习中,可以设计不同的生成器和鉴别器来处理不同的任务,通过对抗学习来提高各个任务的性能
  • 数据增强:对抗学习可以作为一种数据增强技术,通过生成器生成更多的训练样本,帮助模型更好地学习

3、手写数字对抗学习案例

3.1 需求概述

现存在一批手写数字的数据集(类似下图),需要通过GAN来模仿手写数字,使得生成器能生成与手写数字非常相似的图像,且不易被鉴别器识别

3.2 代码实现

Step1: 导包及设备检测

import torch
from torch import nn
from torchvision import datasets
from torchvision import transforms
from torch.utils.data import DataLoader
from matplotlib import pyplot as plt
from matplotlib import gridspec
import os
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
from IPython import display

# 设备检测
device = "cuda" if torch.cuda.is_available() else "cpu"

Step2: 数据集的加载及预处理

# 图像预处理工具
trans = transforms.Compose(transforms=[
    # [0, 255] [H, W, C] --> [0, 1] [C, H, W]
    transforms.ToTensor(),
    # [0, 1] --> [-1, 1]
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# 下载数据集
# datasets.MNIST:这是PyTorch的torchvision.datasets模块中的一个类,用于加载MNIST数据集
# root="data":指定数据集下载和缓存的根目录为当前代码工作目录下的data文件夹中
# train=True:指定是否加载训练集,设置为True表示加载训练集(如果设置为False,则加载测试集)
# transform=trans:指定之前定义好的图像预处理工具
# download=True:指定如果数据集不在本地缓存中,是否自动下载数据集(设置为True表示如果数据集不存在,则自动下载)
data = datasets.MNIST(root="data", train=True, transform=trans, download=True)

# 加载数据集
# dataset=data:指定数据集为上面下载好的MNIST(存放在data目录下)
# batch_size=128:指定了每个数据批次的大小(每批次128个图像)
# shuffle=True:指定是否在每个epoch开始时打乱数据(设置为True意味着数据将在每个epoch开始时被随机打乱,这有助于提高模型训练时的泛化能力)
dataloader = DataLoader(dataset=data, batch_size=128, shuffle=True)

# 遍历一批次数据,看看数据的形状以及样例
for batch_X, batch_y in dataloader:
    print(batch_X.shape)
    print(batch_y.shape)
    plt.imshow(batch_X[0][0])  
    break

Step3: 定义生成器和鉴别器

"""
    定义生成器(制作假图像的网络)
        - 输入:一个向量(又称Z、noise、噪声,随机数, 比如:128维)
        - 输出:一个图像([1, 28, 28]),一个向量(28 * 28)
"""
class Generator(nn.Module):
    def __init__(self, in_features=128, out_features=28 * 28):
        super().__init__()
        # 用三层全连接(中间不管多少维度,最终从128到28*28即可)
        self.linear1 = nn.Linear(in_features=in_features, out_features=256)
        self.linear2 = nn.Linear(in_features=256, out_features=512)
        self.linear3 = nn.Linear(in_features=512, out_features=out_features)
    def forward(self, x):
        # ReLU函数定义为f(x) = max(0, x),即当输入x大于0时,输出x;当输入x小于0时,输出0
        # Tanh函数定义为f(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x)),它是双曲正切函数,输出范围在-1到1之间
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        x = torch.relu(x)
        x = self.linear3(x)
        x = torch.tanh(x)
        return x

"""
    鉴别器(判断真假)
        - 输入:一个图像
        - 输出:图像是真的还是假的(二分类问题,输出真或假这两个类别)
"""
class Discriminator(nn.Module):
    def __init__(self, in_features=28 * 28, out_features=2):
        super().__init__()
        # 用三层全连接(中间不管多少维度,最终从28*28到2即可)
        self.linear1 = nn.Linear(in_features=in_features, out_features=512)
        self.linear2 = nn.Linear(in_features=512, out_features=256)
        self.linear3 = nn.Linear(in_features=256, out_features=32)
        self.linear4 = nn.Linear(in_features=32, out_features=out_features)
    def forward(self, x):
        x = self.linear1(x)
        x = torch.relu(x)
        x = self.linear2(x)
        x = torch.relu(x)
        x = self.linear3(x)
        x = torch.relu(x)
        x = self.linear4(x)
        return x

Step4: 准备训练

# 设置训练轮次
epochs = 1000

# 实例化生成器和鉴别器模型,并进行数据搬家
generator = Generator(in_features=128, out_features=28 * 28)
generator.to(device=device) 
discriminator = Discriminator(in_features=28 * 28, out_features=2)
discriminator.to(device=device)

# 优化器:两个网络,两个优化器,并非同时进行,而是交替进行
g_optimizer = torch.optim.Adam(params=generator.parameters(), lr=1e-4)
d_optimizer = torch.optim.Adam(params=discriminator.parameters(), lr=1e-4)

# 定义损失函数
loss_fn = nn.CrossEntropyLoss()

def get_real_data_labels(size):
    """
        获取真实数据的标签
        真实图像来自于待模仿的对象,也就是手写数字,标签为1
    """
    labels = torch.ones(size, device=device, dtype=torch.long)
    return labels

def get_fake_data_labels(size):
    """
        获取假数据的标签
        假图像来自于生成器生成的结果,也就是输入随机数之后生成的假图像,标签为0
    """
    labels = torch.zeros(size, device=device, dtype=torch.long)
    return labels

def get_noise(size):
    """
        根据随机数获取噪声
    """
    X = torch.randn(size, 128, device=device)
    return X

Step5: 开始训练

"""
    训练过程:
    GANs的训练目标是使得判别器无法区分真实数据和生成器生成的数据,这意味着判别器应该对真实数据输出高概率(接近1),对假数据输出低概率(接近0)
    因此,GANs的训练过程本质上是两个网络(判别器和生成器)之间的对抗过程,而不是单纯地最小化损失函数
"""
for epoch in range(1, epochs + 1):
    # 设置训练模式
    generator.train()
    discriminator.train()
    print(f"当前正在进行第{epoch}轮训练,训练完成之后会显示图像结果")
    # 每一轮都是逐批次训练
    for batch_idx, (batch_X, batch_y) in enumerate(dataloader):
        """
            根据当前批次的样本获取真实数据和假数据
        """
        # 1、获取真实图像 
        # view()方法用于改变张量的形状而不改变其数据,to(device=device)方法用于数据搬家
        # batch_X.size(0)是获取batch_X的第一个维度数据,并保持不变;-1代表除了第一个不变的维度之外的剩下所有维度
        # 即:[N, C, H, W] --> [N, C * H * W]
        real_data = batch_X.view(batch_X.size(0), -1).to(device=device)
        # 2、给生成器准备的输入噪声(噪声的数量即为real_data的数量)
        num_real_data = real_data.size(0)
        noise_with_real = get_noise(num_real_data).to(device=device)
        # 3、通过生成器生成假数据,并用detach()方法取消这个假数据张量的梯度跟踪和权重计算
        fake_data = generator(noise_with_real).detach()

        """
            优化一步鉴别器
            鉴别器需要让假数据的预测结果接近0,真数据的预测结果接近1,防止造假
        """
        # 1、清空鉴别器的梯度
        d_optimizer.zero_grad()
        # 2、把真实数据交给鉴别器去学习
        real_pred = discriminator(real_data)
        # 3、计算损失(真实数据输出的预测结果与1之间的差距)
        real_loss = loss_fn(real_pred, get_real_data_labels(real_data.size(0)))
        # 4、反向传播,计算real_loss关于判别器参数的梯度
        real_loss.backward()
        # 5、把假数据交给鉴别器去鉴别
        fake_pred = discriminator(fake_data)
        # 6、计算损失(假数据输出的预测结果与0之间的差距)
        fake_loss = loss_fn(fake_pred, get_fake_data_labels(fake_data.size(0)))
        # 7、反向传播,计算fake_loss关于判别器参数的梯度
        fake_loss.backward()
        # 8、鉴别器优化一步(根据计算出的梯度更新判别器的参数)
        d_optimizer.step()

        """
            训练一步生成器
            生成器需要让鉴别器预测假数据的结果越来越接近1(也就是让鉴别器觉得假数据是真数据)
        """
        # 1、重新获取一批噪声数据
        noise_with_real_new = get_noise(num_real_data).to(device=device)
        # 2、把新的噪声交给生成器,从而获取新的假数据
        fake_data_new = generator(noise_with_real_new)
        # 3、清空生成器的梯度
        g_optimizer.zero_grad()
        # 4、临时冻结鉴别器
        # discriminator.parameters()返回判别器网络中所有的参数(即权重和偏置),这些参数是可训练的,并且在反向传播中会计算梯度
        # requires_grad=False表示取消梯度计算
        for param in discriminator.parameters():
            param.requires_grad=False
        # 5、根据新的假数据,让鉴别器预测出结果
        d_pred = discriminator(fake_data_new)
        # 6、解冻结鉴别器
        for param in discriminator.parameters():
            param.requires_grad=True
        # 7、互相博弈的关键:计算损失(假数据输出的预测结果与1之间的差距)
        g_loss = loss_fn(d_pred, get_real_data_labels(d_pred.size(0)))
        # 8、反向传播,计算g_loss关于判别器参数的梯度
        g_loss.backward()
        # 9、生成器优化一步(根据计算出的梯度更新判别器的参数)
        g_optimizer.step()

    """
        获取一批测试噪声
        num_test_samples = 16
        每训练一轮,看看训练结果
    """
    num_test_samples = 16
    test_noise = get_noise(num_test_samples)
    generator.eval()
    with torch.no_grad():
        # 正向推理
        img_pred = generator(test_noise)
        img_pred = img_pred.view(img_pred.size(0), 28, 28).cpu().data
        # 清空画布
        display.clear_output(wait=True)
        # 设置画图的大小
        fig = plt.figure(1, figsize=(12, 8)) 
        # 划分为 4 x 4 的 网格
        gs = gridspec.GridSpec(4, 4)
        # 遍历每一个单元格
        for i in range(4):
            for j in range(4):
                # 取每一个图
                X = img_pred[i * 4 + j, :, :]
                # 将对应的图放入对应的网格中
                ax = fig.add_subplot(gs[i, j])
                # 在对应的网格中显示图像
                ax.matshow(X, cmap=plt.get_cmap("Greys"))
                # 将图像坐标 x 轴和 y 轴的刻度设置为空
                ax.set_xticks(())
                ax.set_yticks(())
        plt.show()

3.3 效果展示

1000轮训练之后,效果如下(只显示手写数字1了)

PS:数字1笔画比较简单,容易模仿,所以最后生成器和鉴别器一直在数字1上面做对抗

4、附页

在图像生成方法中,除了GAN,还有一种经典的方法,即:Diffusion(扩散模型)

扩散模型通过一系列逐步增加噪声的过程,将数据分布转换成一个简单的先验分布(如高斯分布),再训练一个模型来逆转这个过程,即:从噪声中逐渐恢复出原始的数据分布

扩散模型包括两个过程:

  • 前向过程(forward process):前向过程又称为扩散过程,从原始图片逐步加噪至一组纯噪声
  • 反向过程(reverse process):反向过程则是将一组随机噪声还原为输入的过程,需要学习一个去噪过程

Stable Diffusion 是一个基于扩散模型的图像生成产品,它能够根据文本提示(Prompt)生成高质量的图像(Prompt思想是指通过输入提示词、提醒词或输入人类的小的特色来指导图像的生成),Stable Diffusion因其高质量的图像生成能力而被广泛使用

Stable Diffusion相关的典型网站和服务:

  • jimeng.jianying.com:剪影(字节旗下)的AI视频平台,提供中文提示词进行创作,AI能够准确把握用户的需求,将抽象的思路转化为具体的视觉效果
  • yige.baidu.com:百度提供的一个AI绘画平台,用户可以在这里进行图像生成
  • midjourney:国外的一个图像生成服务的平台,用户可以通过它来创造新的图像内容

本文转载自: https://blog.csdn.net/weixin_43767064/article/details/144186147
版权归原作者 学不会lostfound 所有, 如有侵权,请联系我们删除。

“三、计算机视觉_09GAN对抗学习案例”的评论:

还没有评论