

DDPM Code Implementation Explained

Code Framework

(figure: overall structure of the DDPM code)


1. Diffusion.py

This file implements the formulas from the theoretical derivation of DDPM.

# extract the coefficients for the current timestep t
def extract(v, t, x_shape):
    """
    Extract some coefficients at specified timesteps, then reshape to
    [batch_size, 1, 1, 1, 1, ...] for broadcasting purposes.
    """
    device = t.device
    out = torch.gather(v, index=t, dim=0).float().to(device)
    return out.view([t.shape[0]] + [1] * (len(x_shape) - 1))

v is a tensor of time-dependent coefficients (e.g. alphas_bar); for each timestep t, the corresponding value is extracted from it.
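A minimal sketch of what extract does (the shapes below are illustrative, not from the original post):

import torch

v = torch.linspace(0.1, 0.9, 5)    # per-timestep coefficient table, e.g. sqrt_alphas_bar with T = 5
t = torch.tensor([0, 2, 4])        # one timestep index per batch element
x_shape = (3, 3, 32, 32)           # [batch, C, H, W]

out = torch.gather(v, index=t, dim=0).float()
out = out.view([t.shape[0]] + [1] * (len(x_shape) - 1))
print(out.shape)                   # torch.Size([3, 1, 1, 1]) -- broadcastable against x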

Training process

GaussianDiffusionTrainer

Handles model training, i.e. learning to predict the noise at a given timestep t.

class GaussianDiffusionTrainer(nn.Module):

Computing the constant coefficients:

    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        # betas buffer: the beta values for timesteps 1..T
        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)  # cumulative product of alphas
        # calculations for diffusion q(x_t | x_{t-1}) and others
        self.register_buffer('sqrt_alphas_bar', torch.sqrt(alphas_bar))
        self.register_buffer('sqrt_one_minus_alphas_bar', torch.sqrt(1. - alphas_bar))

The constructor receives the model, the β range, and the total number of timesteps T. It also precomputes the α values (1 − β), their cumulative products ᾱ, and √ᾱ and √(1 − ᾱ), which the diffusion process needs later.
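A quick numeric sketch of the schedule, assuming the paper's default range β₁ = 1e-4, β_T = 0.02, T = 1000:

import torch

beta_1, beta_T, T = 1e-4, 0.02, 1000
betas = torch.linspace(beta_1, beta_T, T).double()
alphas = 1. - betas
alphas_bar = torch.cumprod(alphas, dim=0)

# alphas_bar decays monotonically from ~1 toward ~0, so sqrt(alphas_bar) * x_0
# fades the signal out while sqrt(1 - alphas_bar) * noise fades the noise in
print(alphas_bar[0].item())    # ~0.9999
print(alphas_bar[-1].item())   # ~4e-5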

Training:
According to the paper, the optimization objective ultimately simplifies to minimizing the difference between the predicted noise and the actual noise, which yields the optimal prediction of the previous step's mean:

$$L_{\text{simple}}(\theta)=\mathbb{E}_{t,\,x_0,\,\epsilon}\Big[\big\|\epsilon-\epsilon_\theta\big(\sqrt{\bar\alpha_t}\,x_0+\sqrt{1-\bar\alpha_t}\,\epsilon,\ t\big)\big\|^2\Big]$$

    def forward(self, x_0):
        """
        Algorithm 1.
        """
        # pick a random timestep t for each sample in the batch
        t = torch.randint(self.T, size=(x_0.shape[0],), device=x_0.device)
        noise = torch.randn_like(x_0)
        x_t = (
            extract(self.sqrt_alphas_bar, t, x_0.shape) * x_0 +
            extract(self.sqrt_one_minus_alphas_bar, t, x_0.shape) * noise)
        # loss: compare the noise predicted at x_t against the actual noise
        loss = F.mse_loss(self.model(x_t, t), noise, reduction='none')
        return loss

This step computes x_t according to

$$x_t=\sqrt{\bar\alpha_t}\,x_0+\sqrt{1-\bar\alpha_t}\,\epsilon,\qquad \epsilon\sim\mathcal{N}(0,I)$$

and then measures the MSE between the noise predicted at timestep t and the actual noise.
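A minimal usage sketch for one training step (the UNet comes from Model.py below; the batch size and import path are illustrative assumptions):

import torch
from Model import UNet  # assumed file layout; adjust to your project

model = UNet(T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
             num_res_blocks=2, dropout=0.1)
trainer = GaussianDiffusionTrainer(model, beta_1=1e-4, beta_T=0.02, T=1000)

x_0 = torch.randn(4, 3, 32, 32)   # stand-in for a batch of normalized images
loss = trainer(x_0).mean()        # reduction='none' above, so reduce here
loss.backward()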

Sampling process

class GaussianDiffusionSampler(nn.Module):

Initialization:

    def __init__(self, model, beta_1, beta_T, T):
        super().__init__()

        self.model = model
        self.T = T

        self.register_buffer('betas', torch.linspace(beta_1, beta_T, T).double())
        alphas = 1. - self.betas
        alphas_bar = torch.cumprod(alphas, dim=0)
        alphas_bar_prev = F.pad(alphas_bar, [1, 0], value=1)[:T]

        self.register_buffer('coeff1', torch.sqrt(1. / alphas))
        self.register_buffer('coeff2', self.coeff1 * (1. - alphas) / torch.sqrt(1. - alphas_bar))

        self.register_buffer('posterior_var', self.betas * (1. - alphas_bar_prev) / (1. - alphas_bar))

These are the coefficients of the reverse diffusion process: coeff1 (the weight of the retained signal), coeff2 (the weight of the predicted noise), and posterior_var (the posterior variance).

Mean (a function of x_t and the predicted noise only):

$$\mu_\theta(x_t,t)=\frac{1}{\sqrt{\alpha_t}}\Big(x_t-\frac{1-\alpha_t}{\sqrt{1-\bar\alpha_t}}\,\epsilon_\theta(x_t,t)\Big)$$

Variance:

$$\tilde\beta_t=\frac{1-\bar\alpha_{t-1}}{1-\bar\alpha_t}\,\beta_t$$

Prediction:
With the posterior conditional probability and the mean/variance formulas above, the previous state is computed from

$$p_\theta(x_{t-1}\mid x_t)=\mathcal{N}\big(x_{t-1};\ \mu_\theta(x_t,t),\ \sigma_t^2 I\big)$$
The mean in code:

    # predict the mean of x_{t-1}
    def predict_xt_prev_mean_from_eps(self, x_t, t, eps):
        assert x_t.shape == eps.shape
        return (
            extract(self.coeff1, t, x_t.shape) * x_t -
            extract(self.coeff2, t, x_t.shape) * eps
        )

The mean and variance are then assembled, with the noise term eps predicted by the model. (Note that the variance actually used is β_t for all but the first step; the posterior variance is substituted only at index 0.)

    def p_mean_variance(self, x_t, t):
        # below: only log_variance is used in the KL computations
        var = torch.cat([self.posterior_var[1:2], self.betas[1:]])
        var = extract(var, t, x_t.shape)

        eps = self.model(x_t, t)
        xt_prev_mean = self.predict_xt_prev_mean_from_eps(x_t, t, eps=eps)
        return xt_prev_mean, var

Sampling:

$$x_{t-1}=\mu_\theta(x_t,t)+\sigma_t z,\qquad z\sim\mathcal{N}(0,I)\ \ (z=0\ \text{when}\ t=0)$$

    def forward(self, x_T):
        """
        Algorithm 2.
        """
        x_t = x_T
        for time_step in reversed(range(self.T)):
            print(time_step)  # progress
            t = x_t.new_ones([x_T.shape[0], ], dtype=torch.long) * time_step
            mean, var = self.p_mean_variance(x_t=x_t, t=t)
            # no noise is added at the final step, i.e. when t == 0
            if time_step > 0:
                noise = torch.randn_like(x_t)
            else:
                noise = 0
            x_t = mean + torch.sqrt(var) * noise
            assert torch.isnan(x_t).int().sum() == 0, "nan in tensor."
        x_0 = x_t
        return torch.clip(x_0, -1, 1)

At each timestep, the trained model predicts the mean and variance of the previous state, and fresh random noise is added.
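A minimal usage sketch (same illustrative hyperparameters as above; the full chain runs T = 1000 reverse steps, so it is slow on CPU):

sampler = GaussianDiffusionSampler(model, beta_1=1e-4, beta_T=0.02, T=1000)

with torch.no_grad():
    x_T = torch.randn(4, 3, 32, 32)   # start from pure Gaussian noise
    x_0 = sampler(x_T)                # runs the full reverse chain, output in [-1, 1]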

2. Model.py

This file builds the U-Net architecture used by DDPM.

# activation function
class Swish(nn.Module):
    def forward(self, x):
        return x * torch.sigmoid(x)

Time embedding

class TimeEmbedding(nn.Module):

Initialization:

    def __init__(self, T, d_model, dim):
        # d_model must be even: each dimension pair holds one sin/cos pair
        assert d_model % 2 == 0
        super().__init__()
        emb = torch.arange(0, d_model, step=2) / d_model * math.log(10000)  # frequency vector
        emb = torch.exp(-emb)
        pos = torch.arange(T).float()
        emb = pos[:, None] * emb[None, :]  # base frequency matrix of the time embedding
        assert list(emb.shape) == [T, d_model // 2]
        emb = torch.stack([torch.sin(emb), torch.cos(emb)], dim=-1)
        assert list(emb.shape) == [T, d_model // 2, 2]
        emb = emb.view(T, d_model)

        self.timembedding = nn.Sequential(
            nn.Embedding.from_pretrained(emb),
            nn.Linear(d_model, dim),
            Swish(),
            nn.Linear(dim, dim),
        )
        self.initialize()

Positional encoding attaches time information to every position in the sequence:

  1. Check that the dimension is even, so every dimension pair gets one sine/cosine pair.
  2. Compute the frequency vector used to generate the sine and cosine waves.
  3. Multiply the position indices pos[:, None] by the frequency vector emb[None, :] to form the base frequency matrix.
  4. Evaluate the sine and cosine values and stack them into the time-embedding matrix emb.
  5. Build the time-embedding layers.
    def initialize(self):
        for module in self.modules():
            if isinstance(module, nn.Linear):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)

    def forward(self, t):
        emb = self.timembedding(t)
        return emb

initialize sets the weights of the linear layers; forward looks the time information up from the embedding layers.
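A small shape check, with illustrative sizes (the UNet below uses d_model = ch and dim = 4 * ch):

import torch

temb_layer = TimeEmbedding(T=1000, d_model=128, dim=512)
t = torch.randint(1000, (8,))   # one timestep index per batch element
print(temb_layer(t).shape)      # torch.Size([8, 512])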

Upsampling and downsampling:

class DownSample(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=2, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        x = self.main(x)
        return x


class UpSample(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.main = nn.Conv2d(in_ch, in_ch, 3, stride=1, padding=1)
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.main.weight)
        init.zeros_(self.main.bias)

    def forward(self, x, temb):
        _, _, H, W = x.shape
        x = F.interpolate(
            x, scale_factor=2, mode='nearest')
        x = self.main(x)
        return x

In DownSample, stride=2 halves the spatial size of the feature map.
In UpSample, F.interpolate with nearest-neighbor interpolation (mode='nearest') doubles the spatial size.
Neither block changes the number of channels.
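A quick shape check (sizes are illustrative):

import torch

x = torch.randn(1, 64, 32, 32)          # [B, C, H, W]
print(DownSample(64)(x, None).shape)    # torch.Size([1, 64, 16, 16])
print(UpSample(64)(x, None).shape)      # torch.Size([1, 64, 64, 64])
# temb is accepted but unused by these two blocks, so None is fine here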

Self-attention:

class AttnBlock(nn.Module):
    def __init__(self, in_ch):
        super().__init__()
        self.group_norm = nn.GroupNorm(32, in_ch)
        self.proj_q = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_k = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj_v = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.proj = nn.Conv2d(in_ch, in_ch, 1, stride=1, padding=0)
        self.initialize()

    def initialize(self):
        for module in [self.proj_q, self.proj_k, self.proj_v, self.proj]:
            init.xavier_uniform_(module.weight)
            init.zeros_(module.bias)
        init.xavier_uniform_(self.proj.weight, gain=1e-5)

    def forward(self, x):
        B, C, H, W = x.shape
        h = self.group_norm(x)
        q = self.proj_q(h)
        k = self.proj_k(h)
        v = self.proj_v(h)

        q = q.permute(0, 2, 3, 1).view(B, H * W, C)
        k = k.view(B, C, H * W)
        w = torch.bmm(q, k) * (int(C) ** (-0.5))
        assert list(w.shape) == [B, H * W, H * W]
        w = F.softmax(w, dim=-1)

        v = v.permute(0, 2, 3, 1).view(B, H * W, C)
        h = torch.bmm(w, v)
        assert list(h.shape) == [B, H * W, C]
        h = h.view(B, H, W, C).permute(0, 3, 1, 2)
        h = self.proj(h)
        return x + h

This block applies self-attention to the input feature map, strengthening the model's expressive power without changing the input/output size. Through attention, the model learns to focus on the important parts of the input and ignore irrelevant information, improving overall performance.
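A quick check that the block preserves the feature-map shape (sizes are illustrative; in_ch must be divisible by the 32 groups of GroupNorm):

import torch

x = torch.randn(2, 64, 16, 16)
print(AttnBlock(64)(x).shape)   # torch.Size([2, 64, 16, 16])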

Residual block:

class ResBlock(nn.Module):
    def __init__(self, in_ch, out_ch, tdim, dropout, attn=False):
        super().__init__()
        self.block1 = nn.Sequential(
            nn.GroupNorm(32, in_ch),
            Swish(),
            nn.Conv2d(in_ch, out_ch, 3, stride=1, padding=1),
        )
        self.temb_proj = nn.Sequential(
            Swish(),
            nn.Linear(tdim, out_ch),
        )
        self.block2 = nn.Sequential(
            nn.GroupNorm(32, out_ch),
            Swish(),
            nn.Dropout(dropout),
            nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1),
        )
        if in_ch != out_ch:
            self.shortcut = nn.Conv2d(in_ch, out_ch, 1, stride=1, padding=0)
        else:
            self.shortcut = nn.Identity()
        if attn:
            self.attn = AttnBlock(out_ch)
        else:
            self.attn = nn.Identity()
        self.initialize()

    def initialize(self):
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                init.xavier_uniform_(module.weight)
                init.zeros_(module.bias)
        init.xavier_uniform_(self.block2[-1].weight, gain=1e-5)

    def forward(self, x, temb):
        h = self.block1(x)
        h += self.temb_proj(temb)[:, :, None, None]
        h = self.block2(h)

        h = h + self.shortcut(x)
        h = self.attn(h)
        return h

U-Net architecture:

class UNet(nn.Module):
    def __init__(self, T, ch, ch_mult, attn, num_res_blocks, dropout):
        super().__init__()
        assert all([i < len(ch_mult) for i in attn]), 'attn index out of bound'
        tdim = ch * 4
        self.time_embedding = TimeEmbedding(T, ch, tdim)

        self.head = nn.Conv2d(3, ch, kernel_size=3, stride=1, padding=1)

Downsampling:

        self.downblocks = nn.ModuleList()
        chs = [ch]  # record output channels during downsampling, for the upsampling path
        now_ch = ch
        for i, mult in enumerate(ch_mult):
            out_ch = ch * mult
            for _ in range(num_res_blocks):
                self.downblocks.append(ResBlock(
                    in_ch=now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
                chs.append(now_ch)
            if i != len(ch_mult) - 1:
                self.downblocks.append(DownSample(now_ch))
                chs.append(now_ch)
  • Because the upsampling path must match the downsampling channel counts, chs = [ch] records every output channel count of the downsampling path;
  • now_ch tracks the current channel count, initialized to ch;
  • for each channel multiplier mult, several residual blocks are added (mitigating vanishing gradients), and the current channel count is updated to the blocks' output channels;
  • unless this is the last stage, a downsampling layer is appended, shrinking the feature map to save computation without changing the channel count. The sketch after this list traces the bookkeeping for a concrete configuration.
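A minimal trace of that bookkeeping, using the configuration from the test at the end of this file (ch = 128, ch_mult = [1, 2, 2, 2], num_res_blocks = 2):

ch, ch_mult, num_res_blocks = 128, [1, 2, 2, 2], 2

chs, now_ch = [ch], ch
for i, mult in enumerate(ch_mult):
    for _ in range(num_res_blocks):
        now_ch = ch * mult
        chs.append(now_ch)       # one entry per ResBlock
    if i != len(ch_mult) - 1:
        chs.append(now_ch)       # one extra entry per DownSample

print(chs)   # [128, 128, 128, 128, 256, 256, 256, 256, 256, 256, 256, 256]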

Middle blocks:

        self.middleblocks = nn.ModuleList([
            ResBlock(now_ch, now_ch, tdim, dropout, attn=True),
            ResBlock(now_ch, now_ch, tdim, dropout, attn=False),
        ])

Two middle blocks sit between the downsampling and upsampling paths to strengthen the network's capacity.

Upsampling:

        self.upblocks = nn.ModuleList()
        for i, mult in reversed(list(enumerate(ch_mult))):
            out_ch = ch * mult
            for _ in range(num_res_blocks + 1):
                self.upblocks.append(ResBlock(
                    in_ch=chs.pop() + now_ch, out_ch=out_ch, tdim=tdim,
                    dropout=dropout, attn=(i in attn)))
                now_ch = out_ch
            if i != 0:
                self.upblocks.append(UpSample(now_ch))
        assert len(chs) == 0

Mirroring the downsampling path, residual blocks and upsampling layers are added; chs.pop() + now_ch accounts for the skip connection concatenated from the downsampling path (see torch.cat in forward below).

Output layer:

        self.tail = nn.Sequential(
            nn.GroupNorm(32, now_ch),
            Swish(),
            nn.Conv2d(now_ch, 3, 3, stride=1, padding=1))
        self.initialize()

    def initialize(self):
        init.xavier_uniform_(self.head.weight)
        init.zeros_(self.head.bias)
        init.xavier_uniform_(self.tail[-1].weight, gain=1e-5)
        init.zeros_(self.tail[-1].bias)
  • nn.GroupNorm(32, now_ch) normalizes the input over 32 channel groups;
  • Swish is the activation function;
  • the output has 3 channels (an RGB image), and the spatial size of the feature map is unchanged;
  • initialize sets the weights and biases.

The forward pass of the U-Net:

    def forward(self, x, t):
        # Timestep embedding
        temb = self.time_embedding(t)
        # Downsampling
        h = self.head(x)
        hs = [h]
        for layer in self.downblocks:
            h = layer(h, temb)
            hs.append(h)
        # Middle
        for layer in self.middleblocks:
            h = layer(h, temb)
        # Upsampling
        for layer in self.upblocks:
            if isinstance(layer, ResBlock):
                h = torch.cat([h, hs.pop()], dim=1)
            h = layer(h, temb)
        h = self.tail(h)

        assert len(hs) == 0
        return h

Instantiating and testing the model in the main program:

if __name__ == '__main__':
    batch_size = 8
    model = UNet(
        T=1000, ch=128, ch_mult=[1, 2, 2, 2], attn=[1],
        num_res_blocks=2, dropout=0.1)
    x = torch.randn(batch_size, 3, 32, 32)
    t = torch.randint(1000, (batch_size,))
    y = model(x, t)
    print(y.shape)

3. Train.py

Below is the code that trains and evaluates DDPM on the CIFAR-10 dataset.

Training

def train(modelConfig: Dict):
    device = torch.device(modelConfig["device"])
    # dataset
    dataset = CIFAR10(
        root='./CIFAR10', train=True, download=True,
        transform=transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
        ]))
    dataloader = DataLoader(
        dataset, batch_size=modelConfig["batch_size"], shuffle=True, num_workers=4, drop_last=True, pin_memory=True)

This selects the device (CPU or GPU), loads CIFAR-10 as dataset (the Normalize transform maps pixels from [0, 1] to [-1, 1]), and wraps it in a dataloader.

    # model setup
    net_model = UNet(T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=modelConfig["dropout"]).to(device)
    if modelConfig["training_load_weight"] is not None:
        net_model.load_state_dict(torch.load(os.path.join(
            modelConfig["save_weight_dir"], modelConfig["training_load_weight"]), map_location=device))
    optimizer = torch.optim.AdamW(
        net_model.parameters(), lr=modelConfig["lr"], weight_decay=1e-4)
    cosineScheduler = optim.lr_scheduler.CosineAnnealingLR(
        optimizer=optimizer, T_max=modelConfig["epoch"], eta_min=0, last_epoch=-1)
    warmUpScheduler = GradualWarmupScheduler(
        optimizer=optimizer, multiplier=modelConfig["multiplier"], warm_epoch=modelConfig["epoch"] // 10, after_scheduler=cosineScheduler)
    trainer = GaussianDiffusionTrainer(
        net_model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)

This instantiates the UNet model, the optimizer, the learning-rate schedule (warm-up followed by cosine annealing), and the diffusion trainer.

    # start training
    for e in range(modelConfig["epoch"]):
        with tqdm(dataloader, dynamic_ncols=True) as tqdmDataLoader:
            for images, labels in tqdmDataLoader:
                # train
                optimizer.zero_grad()
                x_0 = images.to(device)
                loss = trainer(x_0).sum() / 1000.
                loss.backward()
                torch.nn.utils.clip_grad_norm_(
                    net_model.parameters(), modelConfig["grad_clip"])
                optimizer.step()
                tqdmDataLoader.set_postfix(ordered_dict={
                    "epoch": e,
                    "loss: ": loss.item(),
                    "img shape: ": x_0.shape,
                    "LR": optimizer.state_dict()['param_groups'][0]["lr"]})
        warmUpScheduler.step()
        torch.save(net_model.state_dict(), os.path.join(
            modelConfig["save_weight_dir"], 'ckpt_' + str(e) + "_.pt"))

Training steps (within each epoch):

  • clear the gradients
  • move the input batch to the target device
  • compute the loss
  • clip the gradients to prevent gradient explosion
  • update the model parameters
  • then, once per epoch, step the learning-rate scheduler and save the current weights

Evaluation

def eval(modelConfig: Dict):
    # load model and evaluate
    with torch.no_grad():
        device = torch.device(modelConfig["device"])
        model = UNet(T=modelConfig["T"], ch=modelConfig["channel"], ch_mult=modelConfig["channel_mult"], attn=modelConfig["attn"],
                     num_res_blocks=modelConfig["num_res_blocks"], dropout=0.)
        ckpt = torch.load(os.path.join(
            modelConfig["save_weight_dir"], modelConfig["test_load_weight"]), map_location=device)
        model.load_state_dict(ckpt)
        print("model load weight done.")
        model.eval()
        sampler = GaussianDiffusionSampler(
            model, modelConfig["beta_1"], modelConfig["beta_T"], modelConfig["T"]).to(device)
        # Sampled from standard normal distribution
        noisyImage = torch.randn(
            size=[modelConfig["batch_size"], 3, 32, 32], device=device)
        saveNoisy = torch.clamp(noisyImage * 0.5 + 0.5, 0, 1)
        save_image(saveNoisy, os.path.join(
            modelConfig["sampled_dir"], modelConfig["sampledNoisyImgName"]), nrow=modelConfig["nrow"])
        sampledImgs = sampler(noisyImage)
        sampledImgs = sampledImgs * 0.5 + 0.5  # [0 ~ 1]
        save_image(sampledImgs, os.path.join(
            modelConfig["sampled_dir"], modelConfig["sampledImgName"]), nrow=modelConfig["nrow"])

4. Main.py

The main function launches either DDPM training or evaluation, depending on the "state" entry of the config.

from Diffusion.Train import train, eval


def main(model_config=None):
    modelConfig = {
        "state": "train",  # or eval
        "epoch": 200,
        "batch_size": 80,
        "T": 1000,
        "channel": 128,
        "channel_mult": [1, 2, 3, 4],
        "attn": [2],
        "num_res_blocks": 2,
        "dropout": 0.15,
        "lr": 1e-4,
        "multiplier": 2.,
        "beta_1": 1e-4,
        "beta_T": 0.02,
        "img_size": 32,
        "grad_clip": 1.,
        "device": "cuda:0",  # MAKE SURE YOU HAVE A GPU !!!
        "training_load_weight": None,
        "save_weight_dir": "./Checkpoints/",
        "test_load_weight": "ckpt_199_.pt",
        "sampled_dir": "./SampledImgs/",
        "sampledNoisyImgName": "NoisyNoGuidenceImgs.png",
        "sampledImgName": "SampledNoGuidenceImgs.png",
        "nrow": 8,
    }
    if model_config is not None:
        modelConfig = model_config
    if modelConfig["state"] == "train":
        train(modelConfig)
    else:
        eval(modelConfig)


if __name__ == '__main__':
    main()

Reposted from: https://blog.csdn.net/weixin_72965820/article/details/140587611
Copyright belongs to the original author, Smiling639. If there is any infringement, please contact us for removal.
