0


IDDPM官方gituhb项目--训练

在完成IDDPM论文学习后,对github上的官方仓库进行学习,通过具体的代码理解算法实现过程中的一些细节;官方仓库代码基于pytorch实现,链接为https://github.com/openai/improved-diffusion。本笔记主要针对项目中**训练部分代码**进行注释解析,主要涉及仓库项目中的image_train.py、script_util.py、train_util.py、resample.py、dist_util.py文件。

文章目录

image_train.py

本文件是进行图像训练的主要接口,先为训练过程中模型和扩散过程定义所需的参数,然后调用script_util.py文件中定义的函数初始化Unet模型和扩散过程对象,完成模型参数加载和训练数据导入后调用train_util.py文件中定义的TrainLoop类的run_loop()函数开始训练。

"""
Train a diffusion model on images.
"""import argparse

from improved_diffusion import dist_util, logger
from improved_diffusion.image_datasets import load_data
from improved_diffusion.resample import create_named_schedule_sampler
from improved_diffusion.script_util import(
    model_and_diffusion_defaults,
    create_model_and_diffusion,
    args_to_dict,
    add_dict_to_argparser,)from improved_diffusion.train_util import TrainLoop

defmain():
    args = create_argparser().parse_args()# 设置模型和训练所需参数

    dist_util.setup_dist()# 分布式训练
    logger.configure()

    logger.log("creating model and diffusion...")# 初始化UNet和diffusion框架
    model, diffusion = create_model_and_diffusion(**args_to_dict(args, model_and_diffusion_defaults().keys()))
    model.to(dist_util.dev())# 返回前向过程中的时刻t的采样器,分均匀采样和基于loss的采样# args.schedule_sampler设置为loss-second-moment可进行重要性采样,论文中用其在只优化L_vbl时减少梯度噪声
    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, diffusion)

    logger.log("creating data loader...")
    data = load_data(
        data_dir=args.data_dir,
        batch_size=args.batch_size,
        image_size=args.image_size,
        class_cond=args.class_cond,)

    logger.log("training...")# TrainLoop是主要训练对象
    TrainLoop(
        model=model,# 用于逆扩散过程中拟合p_theta的模型,一般是Unet
        diffusion=diffusion,# 扩散过程对象
        data=data,
        batch_size=args.batch_size,
        microbatch=args.microbatch,
        lr=args.lr,
        ema_rate=args.ema_rate,
        log_interval=args.log_interval,
        save_interval=args.save_interval,
        resume_checkpoint=args.resume_checkpoint,
        use_fp16=args.use_fp16,
        fp16_scale_growth=args.fp16_scale_growth,
        schedule_sampler=schedule_sampler,# 训练过程中batch中数据时间步t的采样器
        weight_decay=args.weight_decay,
        lr_anneal_steps=args.lr_anneal_steps,).run_loop()# 初始化模型构建和训练相关的超参数defcreate_argparser():'''从字典中自动生成argument parser'''
    defaults =dict(
        data_dir="",
        schedule_sampler="uniform",
        lr=1e-4,
        weight_decay=0.0,
        lr_anneal_steps=0,
        batch_size=1,
        microbatch=-1,# -1 disables microbatches
        ema_rate="0.9999",# comma-separated list of EMA values
        log_interval=10,
        save_interval=10000,
        resume_checkpoint="",
        use_fp16=False,
        fp16_scale_growth=1e-3,)
    defaults.update(model_and_diffusion_defaults())# 添加模型和扩散过程的参数
    parser = argparse.ArgumentParser()
    add_dict_to_argparser(parser, defaults)return parser

if __name__ =="__main__":
    main()

script_util.py

本文件主要定义了Unet模型和扩散过程对象生成的代码,也包括超分辨率的Unet模型。

import argparse
import inspect

from.import gaussian_diffusion as gd
from.respace import SpacedDiffusion, space_timesteps
from.unet import SuperResModel, UNetModel

NUM_CLASSES =1000# Unet模型和扩散过程对象所需参数defmodel_and_diffusion_defaults():"""
    Defaults for image training.
    """returndict(
        image_size=64,
        num_channels=128,
        num_res_blocks=2,
        num_heads=4,
        num_heads_upsample=-1,
        attention_resolutions="16,8",
        dropout=0.0,
        learn_sigma=False,
        sigma_small=False,
        class_cond=False,
        diffusion_steps=1000,
        noise_schedule="linear",
        timestep_respacing="",
        use_kl=False,
        predict_xstart=False,
        rescale_timesteps=True,
        rescale_learned_sigmas=True,
        use_checkpoint=False,
        use_scale_shift_norm=True,)# 生成Unet模型和高斯扩散过程对象defcreate_model_and_diffusion(
    image_size,# 图片大小
    class_cond,# 生成模型是否有条件;一般就是图片有label信息
    learn_sigma,# 设置模型是预测方差还是使用固定方差
    sigma_small,
    num_channels,
    num_res_blocks,
    num_heads,
    num_heads_upsample,
    attention_resolutions,# 在哪些restblock上进行attention;存放图片的分辨率,当图片降维至该分辨率屎进行自注意力计算
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,):# Unet模型初始化
    model = create_model(
        image_size,
        num_channels,
        num_res_blocks,
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,)# 初始化高斯扩散过程
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        sigma_small=sigma_small,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,)return model, diffusion

# 生成Unet模型defcreate_model(
    image_size,
    num_channels,
    num_res_blocks,
    learn_sigma,
    class_cond,
    use_checkpoint,
    attention_resolutions,# 表示Unet中进行自注意力计算是特征图的分辨率,就是尺寸大小,用于告诉模型何时进行自注意力计算
    num_heads,
    num_heads_upsample,
    use_scale_shift_norm,
    dropout,):# Unet架构中通道乘子,因为随着模型深入,特征图空间尺寸降低,但通道数逐渐增加if image_size ==256:
        channel_mult =(1,1,2,2,4,4)elif image_size ==64:
        channel_mult =(1,2,3,4)elif image_size ==32:
        channel_mult =(1,2,2,2)else:raise ValueError(f"unsupported image size: {image_size}")

    attention_ds =[]for res in attention_resolutions.split(","):# attention_resolutions是[16, 8]
        attention_ds.append(image_size //int(res))# attention_ds是[4, 8],原始尺寸大小除以下采样后的大小就是下采样率return UNetModel(
        in_channels=3,
        model_channels=num_channels,
        out_channels=(3ifnot learn_sigma else6),# 如果设置可学习方差sigma,输出维度就是6,分成两部分,分别预测miu和sigma
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),# 此处已经表示的是需要进行自注意力计算时的下采样率
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond elseNone),
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,)# 超分辨率Unet模型和扩散过程对象所需参数defsr_model_and_diffusion_defaults():
    res = model_and_diffusion_defaults()
    res["large_size"]=256
    res["small_size"]=64
    arg_names = inspect.getfullargspec(sr_create_model_and_diffusion)[0]for k in res.copy().keys():if k notin arg_names:del res[k]return res

# 生成超分辨率Unet和高斯扩散过程对象defsr_create_model_and_diffusion(
    large_size,
    small_size,
    class_cond,
    learn_sigma,
    num_channels,
    num_res_blocks,
    num_heads,
    num_heads_upsample,
    attention_resolutions,
    dropout,
    diffusion_steps,
    noise_schedule,
    timestep_respacing,
    use_kl,
    predict_xstart,
    rescale_timesteps,
    rescale_learned_sigmas,
    use_checkpoint,
    use_scale_shift_norm,):
    model = sr_create_model(
        large_size,
        small_size,
        num_channels,
        num_res_blocks,
        learn_sigma=learn_sigma,
        class_cond=class_cond,
        use_checkpoint=use_checkpoint,
        attention_resolutions=attention_resolutions,
        num_heads=num_heads,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,
        dropout=dropout,)
    diffusion = create_gaussian_diffusion(
        steps=diffusion_steps,
        learn_sigma=learn_sigma,
        noise_schedule=noise_schedule,
        use_kl=use_kl,
        predict_xstart=predict_xstart,
        rescale_timesteps=rescale_timesteps,
        rescale_learned_sigmas=rescale_learned_sigmas,
        timestep_respacing=timestep_respacing,)return model, diffusion

# 生成超分辨率Unetdefsr_create_model(
    large_size,
    small_size,
    num_channels,
    num_res_blocks,
    learn_sigma,
    class_cond,
    use_checkpoint,
    attention_resolutions,
    num_heads,
    num_heads_upsample,
    use_scale_shift_norm,
    dropout,):
    _ = small_size  # hack to prevent unused variableif large_size ==256:
        channel_mult =(1,1,2,2,4,4)elif large_size ==64:
        channel_mult =(1,2,3,4)else:raise ValueError(f"unsupported large size: {large_size}")

    attention_ds =[]for res in attention_resolutions.split(","):
        attention_ds.append(large_size //int(res))return SuperResModel(
        in_channels=3,
        model_channels=num_channels,
        out_channels=(3ifnot learn_sigma else6),
        num_res_blocks=num_res_blocks,
        attention_resolutions=tuple(attention_ds),
        dropout=dropout,
        channel_mult=channel_mult,
        num_classes=(NUM_CLASSES if class_cond elseNone),
        use_checkpoint=use_checkpoint,
        num_heads=num_heads,
        num_heads_upsample=num_heads_upsample,
        use_scale_shift_norm=use_scale_shift_norm,)# 生成扩散过程的框架;虽然初始化的是SpacedDiffusion类,但只要不进行respace,就是一个常规的GaussianDiffusion类defcreate_gaussian_diffusion(*,
    steps=1000,
    learn_sigma=False,
    sigma_small=False,
    noise_schedule="linear",
    use_kl=False,
    predict_xstart=False,
    rescale_timesteps=False,
    rescale_learned_sigmas=False,
    timestep_respacing="",):'''生成扩散模型的框架'''
    betas = gd.get_named_beta_schedule(noise_schedule, steps)# 设置前向的加噪方案,即设置β;可选择设置IDDPM论文提出的余弦加噪方案if use_kl:
        loss_type = gd.LossType.RESCALED_KL  # 只是用kl损失elif rescale_learned_sigmas:
        loss_type = gd.LossType.RESCALED_MSE  # 使用混合损失else:
        loss_type = gd.LossType.MSE  # 使用原始DDPM的损失ifnot timestep_respacing:
        timestep_respacing =[steps]# 调整时间步空间???return SpacedDiffusion(
        use_timesteps=space_timesteps(steps, timestep_respacing),
        betas=betas,
        model_mean_type=(
            gd.ModelMeanType.EPSILON ifnot predict_xstart else gd.ModelMeanType.START_X
        ),# Unet模型预测的值是x_0还是均值
        model_var_type=((
                gd.ModelVarType.FIXED_LARGE
                ifnot sigma_small
                else gd.ModelVarType.FIXED_SMALL
            )ifnot learn_sigma
            else gd.ModelVarType.LEARNED_RANGE
        ),# Unet模型预测的方差是可学习方差,还是使用的固定方差,固定方差中又分大的beta_t或小的beta_bar_t
        loss_type=loss_type,# 损失类型
        rescale_timesteps=rescale_timesteps,)# 将default_dict字典中的参数添加到parser对象中defadd_dict_to_argparser(parser, default_dict):for k, v in default_dict.items():
        v_type =type(v)if v isNone:
            v_type =strelifisinstance(v,bool):
            v_type = str2bool
        parser.add_argument(f"--{k}", default=v,type=v_type)# 从args中按传入的keys构建一个对应的参数字典defargs_to_dict(args, keys):return{k:getattr(args, k)for k in keys}defstr2bool(v):"""
    https://stackoverflow.com/questions/15008758/parsing-boolean-values-with-argparse
    """ifisinstance(v,bool):return v
    if v.lower()in("yes","true","t","y","1"):returnTrueelif v.lower()in("no","false","f","n","0"):returnFalseelse:raise argparse.ArgumentTypeError("boolean value expected")

train_util.py

本文件中主要定义了用于训练的TrainLoop类,其内部是将单个step、单个batch、整体训练过程解耦的,实现方式与PytorchLightning类似。本人认为需要注意的一点的是,为了保证模型进行合理的混合精度训练,TrainLoop类中维护了一个self.master_params变量。本人是将其理解为训练过程中Unet模型参数的一份全精度的备份,混合精度训练时,在训练过程中计算时,模型中是使用半精度类型数据进行计算,但是在进行梯度回传时,会将模型内部的梯度值传递到self.master_params变量中存储的参数进行全精度的参数更新,然后再把更新后的参数传回到模型参数中,完成梯度回传和模型参数更新。训练过程中,会使用到IDDPM提出的一个改善点,即对时间步进行基于损失的重要性重采样。

import copy
import functools
import os

import blobfile as bf
import numpy as np
import torch as th
import torch.distributed as dist
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW

from.import dist_util, logger
from.fp16_util import(
    make_master_params,
    master_params_to_model_params,
    model_grads_to_master_grads,
    unflatten_master_params,
    zero_grad,)from.nn import update_ema
from.resample import LossAwareSampler, UniformSampler

# For ImageNet experiments, this was a good default value.# We found that the lg_loss_scale quickly climbed to# 20-21 within the first ~1K steps of training.
INITIAL_LOG_LOSS_SCALE =20.0# 定义的模型训练类,有点类似PytorchLightning,将训练过程封装为一个接口classTrainLoop:def__init__(
            self,*,
            model,
            diffusion,
            data,
            batch_size,
            microbatch,
            lr,
            ema_rate,
            log_interval,
            save_interval,
            resume_checkpoint,
            use_fp16=False,
            fp16_scale_growth=1e-3,
            schedule_sampler=None,
            weight_decay=0.0,
            lr_anneal_steps=0,):
        self.model = model  # Unet模型
        self.diffusion = diffusion  # 扩散过程对象
        self.data = data  # 训练数据
        self.batch_size = batch_size
        self.microbatch = microbatch if microbatch >0else batch_size  # 多卡训练单卡上的batch???
        self.lr = lr
        self.ema_rate =([ema_rate]ifisinstance(ema_rate,float)else[float(x)for x in ema_rate.split(",")])
        self.log_interval = log_interval  # 日志记录间隔
        self.save_interval = save_interval  # 模型保存间隔
        self.resume_checkpoint = resume_checkpoint
        self.use_fp16 = use_fp16  # 是否进行半精度训练
        self.fp16_scale_growth = fp16_scale_growth
        self.schedule_sampler = schedule_sampler or UniformSampler(diffusion)# 时间步采样器
        self.weight_decay = weight_decay
        self.lr_anneal_steps = lr_anneal_steps  # 学习率回火steps

        self.step =0
        self.resume_step =0
        self.global_batch = self.batch_size * dist.get_world_size()

        self.model_params =list(self.model.parameters())
        self.master_params = self.model_params
        self.lg_loss_scale = INITIAL_LOG_LOSS_SCALE
        self.sync_cuda = th.cuda.is_available()

        self._load_and_sync_parameters()# Unet模型加载参数并同步到多卡上if self.use_fp16:# 如果使用半精度训练
            self._setup_fp16()# 先将模型参数以全精度形式备份,然后将其转为半精度

        self.opt = AdamW(self.master_params, lr=self.lr, weight_decay=self.weight_decay)# 优化器if self.resume_step:# 如果时在断点上继续训练
            self._load_optimizer_state()# 给优化器加载参数# Model was resumed, either due to a restart or a checkpoint being specified at the command line.
            self.ema_params =[self._load_ema_parameters(rate)for rate in self.ema_rate]else:
            self.ema_params =[copy.deepcopy(self.master_params)for _ inrange(len(self.ema_rate))]if th.cuda.is_available():
            self.use_ddp =True# 分布式训练
            self.ddp_model = DDP(
                self.model,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,)else:if dist.get_world_size()>1:
                logger.warn("Distributed training requires CUDA. ""Gradients will not be synchronized properly!")
            self.use_ddp =False
            self.ddp_model = self.model

    # 分布式训练中模型加载和同步所有参数def_load_and_sync_parameters(self):
        resume_checkpoint = find_resume_checkpoint()or self.resume_checkpoint  # 已经训练过程checkpoint保存的文件路径if resume_checkpoint:
            self.resume_step = parse_resume_step_from_filename(resume_checkpoint)# 从文件路径中解析处保存时的stepif dist.get_rank()==0:
                logger.log(f"loading model from checkpoint: {resume_checkpoint}...")
                self.model.load_state_dict(
                    dist_util.load_state_dict(
                        resume_checkpoint, map_location=dist_util.dev()))# 给Unet模型加载参数

        dist_util.sync_params(self.model.parameters())# 给多卡上的模型同步参数def_load_ema_parameters(self, rate):
        ema_params = copy.deepcopy(self.master_params)

        main_checkpoint = find_resume_checkpoint()or self.resume_checkpoint
        ema_checkpoint = find_ema_checkpoint(main_checkpoint, self.resume_step, rate)if ema_checkpoint:if dist.get_rank()==0:
                logger.log(f"loading EMA from checkpoint: {ema_checkpoint}...")
                state_dict = dist_util.load_state_dict(
                    ema_checkpoint, map_location=dist_util.dev())
                ema_params = self._state_dict_to_master_params(state_dict)

        dist_util.sync_params(ema_params)return ema_params

    # 给优化器加载参数def_load_optimizer_state(self):
        main_checkpoint = find_resume_checkpoint()or self.resume_checkpoint
        opt_checkpoint = bf.join(
            bf.dirname(main_checkpoint),f"opt{self.resume_step:06}.pt")# 优化器参数保存的文件路径if bf.exists(opt_checkpoint):
            logger.log(f"loading optimizer state from checkpoint: {opt_checkpoint}")
            state_dict = dist_util.load_state_dict(
                opt_checkpoint, map_location=dist_util.dev())
            self.opt.load_state_dict(state_dict)# 将模型的参数设置为半精度def_setup_fp16(self):
        self.master_params = make_master_params(self.model_params)# 先将模型的参数以全精度的格式备份一份
        self.model.convert_to_fp16()# 让后再将模型所有的参数转为半精度# 主要的训练函数defrun_loop(self):while(not self.lr_anneal_steps
                or self.step + self.resume_step < self.lr_anneal_steps
        ):
            batch, cond =next(self.data)# 一个batch的数据,cond应该是label等条件信息
            self.run_step(batch, cond)# 执行一个batch的训练过程if self.step % self.log_interval ==0:
                logger.dumpkvs()if self.step % self.save_interval ==0:
                self.save()# 模型、优化器等参数保存# Run for a finite amount of time in integration tests.if os.environ.get("DIFFUSION_TRAINING_TEST","")and self.step >0:return
            self.step +=1# Save the last checkpoint if it wasn't already saved.if(self.step -1)% self.save_interval !=0:
            self.save()# 单个batch的训练函数defrun_step(self, batch, cond):
        self.forward_backward(batch, cond)if self.use_fp16:
            self.optimize_fp16()#else:
            self.optimize_normal()# 优化器更新
        self.log_step()defforward_backward(self, batch, cond):
        zero_grad(self.model_params)# 清除模型参数的梯度for i inrange(0, batch.shape[0], self.microbatch):
            micro = batch[i: i + self.microbatch].to(dist_util.dev())
            micro_cond ={
                k: v[i: i + self.microbatch].to(dist_util.dev())for k, v in cond.items()}
            last_batch =(i + self.microbatch)>= batch.shape[0]# 随后的一个microbatch
            t, weights = self.schedule_sampler.sample(micro.shape[0], dist_util.dev())# 时间步采样# 使用functools.partial接口传入在扩散过程对象中定义损失计算函数和所需参数定义了一个用于计算损失函数
            compute_losses = functools.partial(
                self.diffusion.training_losses,
                self.ddp_model,
                micro,
                t,
                model_kwargs=micro_cond,)# 实际损失计算if last_batch ornot self.use_ddp:
                losses = compute_losses()else:with self.ddp_model.no_sync():# 在非同步的情况下计算损失
                    losses = compute_losses()ifisinstance(self.schedule_sampler, LossAwareSampler):
                self.schedule_sampler.update_with_local_losses(
                    t, losses["loss"].detach())# 论文中提出的改善点,基于训练损失的时间步重要性重采样

            loss =(losses["loss"]* weights).mean()
            log_loss_dict(self.diffusion, t,{k: v * weights for k, v in losses.items()})# 记录训练损失# 梯度回传if self.use_fp16:
                loss_scale =2** self.lg_loss_scale
                (loss * loss_scale).backward()else:
                loss.backward()# 半精度训练时优化器更新defoptimize_fp16(self):ifany(not th.isfinite(p.grad).all()for p in self.model_params):
            self.lg_loss_scale -=1
            logger.log(f"Found NaN, decreased lg_loss_scale to {self.lg_loss_scale}")return

        model_grads_to_master_grads(self.model_params, self.master_params)# 将当前模型参数的梯度赋值给self.master_params,即备份参数
        self.master_params[0].grad.mul_(1.0/(2** self.lg_loss_scale))
        self._log_grad_norm()# 记录模型所有参数计算的正则项
        self._anneal_lr()# 学习率回火
        self.opt.step()# 优化器更新for rate, params inzip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)# 使用指数移动平均值更新目标参数,使其更接近源参数;即模型参数更新# 将备份的参数再传回给模型参数,这么做的目的应该是训练可以使用半精度,但是在模型梯度更新时还是要使用全精度
        master_params_to_model_params(self.model_params, self.master_params)
        self.lg_loss_scale += self.fp16_scale_growth

    # 全精度训练时优化器更新defoptimize_normal(self):
        self._log_grad_norm()# 记录模型所有参数计算的正则项
        self._anneal_lr()# 学习率回火
        self.opt.step()# 优化器更新for rate, params inzip(self.ema_rate, self.ema_params):
            update_ema(params, self.master_params, rate=rate)# 使用指数移动平均值更新目标参数,使其更接近源参数;即模型参数更新def_log_grad_norm(self):
        sqsum =0.0for p in self.master_params:
            sqsum +=(p.grad **2).sum().item()
        logger.logkv_mean("grad_norm", np.sqrt(sqsum))def_anneal_lr(self):ifnot self.lr_anneal_steps:return
        frac_done =(self.step + self.resume_step)/ self.lr_anneal_steps
        lr = self.lr *(1- frac_done)# 新的学习率值for param_group in self.opt.param_groups:
            param_group["lr"]= lr  # 给优化器各参数更新学习率deflog_step(self):
        logger.logkv("step", self.step + self.resume_step)
        logger.logkv("samples",(self.step + self.resume_step +1)* self.global_batch)if self.use_fp16:
            logger.logkv("lg_loss_scale", self.lg_loss_scale)# 模型参数保存defsave(self):defsave_checkpoint(rate, params):
            state_dict = self._master_params_to_state_dict(params)if dist.get_rank()==0:
                logger.log(f"saving model {rate}...")ifnot rate:
                    filename =f"model{(self.step + self.resume_step):06d}.pt"else:
                    filename =f"ema_{rate}_{(self.step + self.resume_step):06d}.pt"with bf.BlobFile(bf.join(get_blob_logdir(), filename),"wb")as f:
                    th.save(state_dict, f)

        save_checkpoint(0, self.master_params)# 保存模型参数for rate, params inzip(self.ema_rate, self.ema_params):
            save_checkpoint(rate, params)# 保存ema_params参数if dist.get_rank()==0:with bf.BlobFile(
                    bf.join(get_blob_logdir(),f"opt{(self.step + self.resume_step):06d}.pt"),"wb",)as f:
                th.save(self.opt.state_dict(), f)# 保存优化器参数

        dist.barrier()# 所有参数同步def_master_params_to_state_dict(self, master_params):if self.use_fp16:
            master_params = unflatten_master_params(self.model.parameters(), master_params)# 将master_params尺寸还原
        state_dict = self.model.state_dict()# 模型所有的参数for i,(name, _value)inenumerate(self.model.named_parameters()):assert name in state_dict
            state_dict[name]= master_params[i]# 以模型参数中name从备份参数master_parems中找对应的值return state_dict

    # 从state_dict中将参数传给master_params中def_state_dict_to_master_params(self, state_dict):
        params =[state_dict[name]for name, _ in self.model.named_parameters()]if self.use_fp16:return make_master_params(params)else:return params

# 从一个checkpoint文件路径名中解析保存时的stepdefparse_resume_step_from_filename(filename):"""
    Parse filenames of the form path/to/modelNNNNNN.pt, where NNNNNN is the
    checkpoint's number of steps.
    """
    split = filename.split("model")iflen(split)<2:return0
    split1 = split[-1].split(".")[0]try:returnint(split1)except ValueError:return0defget_blob_logdir():return os.environ.get("DIFFUSION_BLOB_LOGDIR", logger.get_dir())deffind_resume_checkpoint():# On your infrastructure, you may want to override this to automatically# discover the latest checkpoint on your blob storage, etc.returnNonedeffind_ema_checkpoint(main_checkpoint, step, rate):if main_checkpoint isNone:returnNone
    filename =f"ema_{rate}_{(step):06d}.pt"
    path = bf.join(bf.dirname(main_checkpoint), filename)if bf.exists(path):return path
    returnNone# 记录损失deflog_loss_dict(diffusion, ts, losses):for key, values in losses.items():
        logger.logkv_mean(key, values.mean().item())# Log the quantiles (four quartiles, in particular).for sub_t, sub_loss inzip(ts.cpu().numpy(), values.detach().cpu().numpy()):
            quartile =int(4* sub_t / diffusion.num_timesteps)
            logger.logkv_mean(f"{key}_q{quartile}", sub_loss)

resample.py

IDDPM论文中提到,训练时对时间步进行均匀采样会在

      L 
     
     
     
       v 
      
     
       l 
      
     
       b 
      
     
    
   
  
    L_{vlb} 
   
  
Lvlb​中引入不必要的噪声,为了解决该问题就是进行重要性重采样,具体的实现方式就是在随机进行时间步采样时,会使用一个动态更新的历史损失为各个时间步计算对应的采样权重,而不是进行各个时间步采样概率相同的均匀采样。该动态更新的历史损失,就是构造一个尺寸为[T, 10]的矩阵,即为整个扩散过程的T步存储10个最新的损失值;在将该历史损失矩阵所有的值填满前,还是进行均匀采样;在填满之后,基于损失为每个时间步计算随机采样时的权重,为下一个step训练时更新时间步的采样权重;并且如果使用该类型的时间步sampler,在一个step中损失计算后,还需要调用update_with_local_losses函数将新计算得到的损失填入到历史损失矩阵最后端,并将最前端的最旧的历史损失弹出。
from abc import ABC, abstractmethod

import numpy as np
import torch as th
import torch.distributed as dist

# 返回时间步的采样器defcreate_named_schedule_sampler(name, diffusion):"""
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler.
    :param diffusion: the diffusion object to sample for.
    """if name =="uniform":return UniformSampler(diffusion)# 均匀采样elif name =="loss-second-moment":return LossSecondMomentResampler(diffusion)# 基于二阶动量平滑losselse:raise NotImplementedError(f"unknown schedule sampler: {name}")classScheduleSampler(ABC):"""
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.扩散过程中随时间步长的分布,旨在减少目标的方差

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.默认情况下,采样器执行无偏重要性抽样,其中目标的均值保持不变。
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.
    但是,子类可以覆盖 sample() 以更改重新采样项的重新加权方式,从而允许目标的实际更改。
    """@abstractmethoddefweights(self):"""
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """# 一个batch内数据的重要性采样时间步defsample(self, batch_size, device):"""
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices.
                 - weights: a tensor of weights to scale the resulting losses.
        """
        w = self.weights()# 所有时间步的权重
        p = w / np.sum(w)# 每个时间步的权重除去权重之和# 从range(len(p))中以概率p随机抽取大小为size的数据;p指定的是序列range(len(p))中每个元素出现的概率
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)# 相当于是概率p为指导从range(len(p))随机采样了batch_size个t对应的索引
        indices = th.from_numpy(indices_np).long().to(device)
        weights_np =1/(len(p)* p[indices_np])# 为batch中每个对象设置新的权重
        weights = th.from_numpy(weights_np).float().to(device)return indices, weights

# 时间步均匀采样classUniformSampler(ScheduleSampler):def__init__(self, diffusion):
        self.diffusion = diffusion
        # 权重均为1,使得ScheduleSampler的sample函数中的p的概率都是一样的,故使用np.random.choice采样时是均匀采样
        self._weights = np.ones([diffusion.num_timesteps])defweights(self):# 重载ScheduleSampler中的weights函数return self._weights

# 使用损失更新weights的重要性采样classLossAwareSampler(ScheduleSampler):defupdate_with_local_losses(self, local_ts, local_losses):"""
        Update the reweighting using losses from a model.使用模型中的损失更新重新加权

        Call this method from each rank with a batch of timesteps and the
        corresponding losses for each of those timesteps.
        This method will perform synchronization to make sure all of the ranks
        maintain the exact same reweighting.

        :param local_ts: an integer Tensor of timesteps.时间步的整数张量
        :param local_losses: a 1D Tensor of losses.损失的一维张量
        """
        batch_sizes =[
            th.tensor([0], dtype=th.int32, device=local_ts.device)for _ inrange(dist.get_world_size())]
        dist.all_gather(
            batch_sizes,
            th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),)# Pad all_gather batches to be the maximum batch size.
        batch_sizes =[x.item()for x in batch_sizes]
        max_bs =max(batch_sizes)# 将多卡上的batch整合后的最大batch_size

        timestep_batches =[th.zeros(max_bs).to(local_ts)for bs in batch_sizes]
        loss_batches =[th.zeros(max_bs).to(local_losses)for bs in batch_sizes]
        dist.all_gather(timestep_batches, local_ts)
        dist.all_gather(loss_batches, local_losses)
        timesteps =[
            x.item()for y, bs inzip(timestep_batches, batch_sizes)for x in y[:bs]]
        losses =[x.item()for y, bs inzip(loss_batches, batch_sizes)for x in y[:bs]]
        self.update_with_all_losses(timesteps, losses)@abstractmethoddefupdate_with_all_losses(self, ts, losses):"""
        Update the reweighting using losses from a model.

        Sub-classes should override this method to update the reweighting
        using losses from the model.

        This method directly updates the reweighting without synchronizing
        between workers. It is called by update_with_local_losses from all
        ranks with identical arguments. Thus, it should have deterministic
        behavior to maintain state across workers.

        :param ts: a list of int timesteps.
        :param losses: a list of float losses, one per timestep.
        """classLossSecondMomentResampler(LossAwareSampler):def__init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
        self.diffusion = diffusion
        self.history_per_term = history_per_term  # 论文中提到的“保留每个损失项的前10个值”,是针对每个时间步t
        self.uniform_prob = uniform_prob
        self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64
        )# diffusion.num_timesteps是设置的训练采样的总步数,即T;故self._loss_history是为0到T-1中的每个时间步t存放10个损失值
        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)# 表征self._loss_history中对应列是否填充defweights(self):# 重载的weights的函数ifnot self._warmed_up():# 未完成warm_up就进行均匀采样;即self._loss_history中数据还未填满return np.ones([self.diffusion.num_timesteps], dtype=np.float64)# 基于历史损失的权重更新
        weights = np.sqrt(np.mean(self._loss_history **2, axis=-1))# 对历史损失平方后取均值再开方
        weights /= np.sum(weights)
        weights *=1- self.uniform_prob
        weights += self.uniform_prob /len(weights)return weights

    # 会改变类中的self._loss_history进而改变self.weights的返回数据,进而改变self.sampler中的采样结果defupdate_with_all_losses(self, ts, losses):for t, loss inzip(ts, losses):if self._loss_counts[t]== self.history_per_term:# 如果self._loss_history已经填满# Shift out the oldest loss term.移除第一列损失,将新的损失补充为最后一列
                self._loss_history[t,:-1]= self._loss_history[t,1:]
                self._loss_history[t,-1]= loss
            else:# 如果self._loss_history未填满
                self._loss_history[t, self._loss_counts[t]]= loss  # 用新传入的损失补充为最后一列
                self._loss_counts[t]+=1# 填充列数加一# 用于判断self._loss_history中的数据是否填满,填满之前都是进行均匀采样def_warmed_up(self):return(self._loss_counts == self.history_per_term).all()

dist_util.py

本文件主要为多卡分布式训练定义一些辅助函数

"""
Helpers for distributed training.
"""import io
import os
import socket

import blobfile as bf
from mpi4py import MPI
import torch as th
import torch.distributed as dist

# Change this to reflect your cluster layout.# The GPU for a given rank is (rank % GPUS_PER_NODE).
GPUS_PER_NODE =8

SETUP_RETRY_COUNT =3defsetup_dist():"""
    Setup a distributed process group.
    """if dist.is_initialized():return

    comm = MPI.COMM_WORLD
    backend ="gloo"ifnot th.cuda.is_available()else"nccl"if backend =="gloo":
        hostname ="localhost"else:
        hostname = socket.gethostbyname(socket.getfqdn())
    os.environ["MASTER_ADDR"]= comm.bcast(hostname, root=0)
    os.environ["RANK"]=str(comm.rank)
    os.environ["WORLD_SIZE"]=str(comm.size)

    port = comm.bcast(_find_free_port(), root=0)
    os.environ["MASTER_PORT"]=str(port)
    dist.init_process_group(backend=backend, init_method="env://")defdev():"""
    Get the device to use for torch.distributed.
    """if th.cuda.is_available():return th.device(f"cuda:{MPI.COMM_WORLD.Get_rank()% GPUS_PER_NODE}")return th.device("cpu")defload_state_dict(path,**kwargs):"""
    Load a PyTorch file without redundant fetches across MPI ranks.
    """if MPI.COMM_WORLD.Get_rank()==0:with bf.BlobFile(path,"rb")as f:
            data = f.read()else:
        data =None
    data = MPI.COMM_WORLD.bcast(data)return th.load(io.BytesIO(data),**kwargs)# 从主GPU开始同步张量序列defsync_params(params):"""
    Synchronize a sequence of Tensors across ranks from rank 0.
    """for p in params:with th.no_grad():
            dist.broadcast(p,0)def_find_free_port():try:
        s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        s.bind(("",0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR,1)return s.getsockname()[1]finally:
        s.close()

本笔记主要记录IDDPM官方仓库中训练部分相关代码,其中包含了IDDPM的一个改善点,即基于损失的时间步重要性重采样。本笔记中的项目代码虽然没有模型构造中那么多公式,但也最好能与论文对比学习,读者可参考此笔记IDDPM论文阅读辅助理解。读者若发现问题或错误,请评论指出,互相学习。


本文转载自: https://blog.csdn.net/zzfive/article/details/128060655
版权归原作者 zzfive 所有, 如有侵权,请联系我们删除。

“IDDPM官方gituhb项目--训练”的评论:

还没有评论