Pytorch 多卡并行训练教程 (DDP)
在使用GPU训练大模型时,往往会面临单卡显存不足的情况,这时候就希望通过多卡并行的形式来扩大显存。PyTorch主要提供了两个类来实现多卡并行分别是
- torch.nn.DataParallel(DP)
- torch.nn.DistributedDataParallel(DDP)
关于这两者的区别和原理也有许多博客如Pytorch 并行训练(DP, DDP)的原理和应用; DDP系列第一篇:入门教程进行总结,这里就不在赘述了。不过总结来说的话:DP 比较简单,对小白比较友好,一行代码便可以搞定。DDP 每个进程对应一个独立的训练过程,且只对梯度等少量数据进行信息交换。每个进程包含独立的解释器和 GIL。
博主能力有限,很多原理上的东西看得不是特别懂,所以理解起来也比较肤浅,但是编程的时候一直没找到一套合适的蓝本,最终参考了很多网上的博客,吭哧吭哧写了一套不会报错的代码出来,下面把我个人的理解整理出来,不当之处希望大家指出,一起交流学习。后续可能会随着自己的理解的加深持续完善。
主要参考了以下一些博客:
- PyTorch 并行训练指南:单机多卡并行、混合精度、同步 BN 训练
- Pytorch 并行训练(DP, DDP)的原理和应用
- pytorch多gpu并行训练
- DDP系列第一篇:入门教程
- 单机多卡训练 踩坑记录
初始化
增加参数local_rank来确定当前进程使用哪块GPU, 用于在每个进程中指定不同的device。
defparse():
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank',type=int, default=0)
args = parser.parse_args()return args
defmain():
args = parse()
torch.cuda.set_device(args.local_rank)
torch.distributed.init_process_group('nccl',
init_method='env://')
device = torch.device(f'cuda:{args.local_rank}')
其中 torch.distributed.init_process_group 用于初始化GPU通信方式(NCCL)和参数的获取方式(env代表通过环境变量)。
设置随机种子点
假如model中用到了随机数种子来保证可复现性, 那么此时不能再用固定的常数作为seed, 否则会导致DDP中的所有进程都拥有一样的seed, 进而生成同态性的数据, 因此需要在程序中显示地设置随机种子点。
# 固定随机种子点
seed = np.random.randint(1,10000)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
Dataloader
对于数据加载,在初始化 data loader 的时候需要使用到 torch.utils.data.distributed.DistributedSampler 这个函数:
train_dataset =...
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)# 这个sampler会自动分配数据到各个gpu上
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=opts.batch_size, sampler=train_sampler)
通过以上的函数便可以给每个进程一个不同的 sampler,告诉每个进程自己分别取哪些数据。
在每一个epoch开始的阶段需要为sampler重新设定eopch即:
for ep inrange(total_epoch):
train_sampler.set_epoch(ep)
这样做的目的是:如果在DistributedSampler设置了shuffle,DistributedSampler使用当前epoch作为随机数种子,从而使得不同epoch下有不同的shuffle结果,但是在DistributedSampler源代码中默认的epoch为0,那么每次dataloader获取的shuffle都是相同的。所以,每次 epoch 开始前都需要要调用 sampler 的 set_epoch 方法,这样才能让数据集随机 shuffle 起来。
模型初始化
对于模型的处理主要包括模型初始化,将模型加载至CUDA;加载预训练权重;或利用主进程的权重 初始化所有的进程;将模型中的BN转换为SyncBN;设置模型并行。
由于 BN 层需要基于传入模型的数据计算均值和方差,造成普通 BN 在多卡模式下实际上就是单卡模式。此时需要使用 SyncBN 利用DDP的分布式计算接口来实现真正的多卡BN。
SyncBN利用分布式通讯接口在各卡间进行通讯,传输各自进程小 batch mean 和小 batch variance,在传输少量数据的基础上利用所有数据进行BN计算。
同时由于 SyncBN 用到 all_gather 这个分布式计算接口,而使用这个接口需要先初始化DDP环境,因此 SyncBN 需要在 DDP 环境初始化后初始化,但是要在 DDP 模型前就准备好。
最后由于 SyncBN 是直接搜索 model 中每个 module,如果这个 module 是 torch.nn.modules.batchnorm._BatchNorm 的子类,就将其替换为 SyncBN。因此如果你的 Normalization 层是自己定义的特殊类,没有继承过 _BatchNorm 类,那么convert_sync_batchnorm 是不支持的,需要你自己实现一个新的SyncBN!
defparse():
parser = argparse.ArgumentParser()
parser.add_argument('--local_rank',type=int, default=0)
parser.add_argument('--device',type=str, default='cuda',help='device id (i.e. 0 or 0,1 or cpu)')
parser.add_argument('--resume',type=str, default=None,help='specified the dir of saved models for resume the training')
args = parser.parse_args()return args
args = parse()
device = torch.device(args.device)
model = mymodel().to(device)if args.resume:
checkpoint = torch.load(model_save_path, map_location=device)
model.load_state_dict(checkpoint['model'])else:
save_path ='initial_weights.pth'if opts.local_rank ==0:
torch.save(model.state_dict(), save_path)
dist.barrier()# 这里注意,一定要指定map_location参数,否则会导致第一块GPU占用更多资源
model.load_state_dict(torch.load(save_path, map_location=device))## 设置同步
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)## 设置模型并行
model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True)## 注意要使用find_unused_parameters=True,因为有时候模型里面定义的一些模块 在forward函数里面没有调用,如果不使用find_unused_parameters=True 会报错
输出日志设置
在每一次需要输出或打印日志时都应该先使用
opts.local_rank == 0
来判断,也就是在主进程才执行一些操作,不然日志或者打印的结果会非常混乱。
logger =Noneif opts.local_rank ==0:
log_dir = os.path.join(opts.display_dir,'logger', opts.name)
os.makedirs(log_dir, exist_ok=True)
log_path = os.path.join(log_dir,'log.txt')if os.path.exists(log_path):
os.remove(log_path)
logger = logger_config(log_path=log_path, logging_name='Timer')
logger.info('Parameter Space: ABS: {:.1f}, REL: {:.4f}'.format(count_parameters(MPF_model), count_parameters(MPF_model)/1024/1024))
logger.info(MPF_model)
模型保存
state ={'model':model.module.state_dict(),'ep':ep,'total_it':total_it}
save_path = os.path.join(self.model_dir,'model_{:0>5d}.pth'.format(ep))
torch.save(state, save_path)
在保存模型是需要注意的是,保存的是
{'model':model.module.state_dict()}
, 而不是我们之前的
{'model':model.state_dict()}
, 因为在使用DDP后,原来的model会被封装为新的model的module属性里。
启动方式
PyTorch为提供了一个很方便的启动器 torch.distributed.lunch 用于启动文件,所以可以将运行训练代码的方式调整成下面这样:
CUDA_VISIBLE_DEVICES=0,1 python -m torch.distributed.launch --nproc_per_node=2 train.py
最后附上完成了train代码和超参解析代码:
train.py
import torch.optim as optim
from create_dataset import*from utils import*from MPFNet_Trans_skip import MPFNet
from options import*from saver import Saver, resume
from time import time
from tqdm import tqdm
from optimizer import Optimizer
import datetime
import torch.distributed as dist
defmain():# parse options
parser = TrainOptions()
opts = parser.parse()# define model, optimiser and scheduler
torch.cuda.set_device(opts.local_rank)
torch.distributed.init_process_group('nccl', init_method='env://')# device = torch.device(f'cuda:{opts.local_rank}') #device 这样的设置可能会有问题
device = torch.device(opts.gpu)# device = torch.device("cuda:{}".format(opts.gpu) if torch.cuda.is_available() else "cpu")# 固定随机种子
seed = np.random.randint(1,10000)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)# define dataset
train_dataset = MSRSData(opts, is_train=True)
train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset, shuffle=True)
train_loader = torch.utils.data.DataLoader(
dataset=train_dataset,
batch_size=opts.batch_size,
num_workers = opts.nThreads,
sampler=train_sampler,
pin_memory=False,)
test_dataset = MSRSData(opts, is_train=False)
test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
test_loader = torch.utils.data.DataLoader(
dataset=test_dataset,
batch_size=12,
sampler=test_sampler,
num_workers = opts.nThreads,)## 先加载dataloader 计算每个epoch的的迭代步数 然后计算总的迭代步数
ep_iter =len(train_loader)
max_iter = opts.n_ep * ep_iter
if opts.local_rank ==0:print('Training iter: {}'.format(max_iter))print(opts.local_rank)## 初始化模型
MPF_model = MPFNet(opts.class_nb).to(device)
momentum =0.9
weight_decay =5e-4
lr_start =1e-3# max_iter = 150000
power =0.9
warmup_steps =1000
warmup_start_lr =1e-5
optimizer = Optimizer(
model = MPF_model,
lr0 = lr_start,
momentum = momentum,
wd = weight_decay,
warmup_steps = warmup_steps,
warmup_start_lr = warmup_start_lr,
max_iter = max_iter,
power = power)if opts.resume:if opts.local_rank ==0:
MPF_model, ep, total_it = resume(MPF_model, opts.resume, device)
optimizer = Optimizer(
model = MPF_model,
lr0 = lr_start,
momentum = momentum,
wd = weight_decay,
warmup_steps = warmup_steps,
warmup_start_lr = warmup_start_lr,
max_iter = max_iter,
power = power,
it=total_it)
lr = optimizer.get_lr()print('lr:{}'.format(lr))else:
model_dir = os.path.join(opts.result_dir, opts.name)
os.makedirs(model_dir, exist_ok=True)
save_path = os.path.join(model_dir,'initial_weights.pth')if opts.local_rank ==0:
torch.save(MPF_model.state_dict(), save_path)
dist.barrier()# 这里注意,一定要指定map_location参数,否则会导致第一块GPU占用更多资源
MPF_model.load_state_dict(torch.load(save_path, map_location=device))
ep =-1
total_it =0
ep +=1
MPF_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(MPF_model)
MPF_model = torch.nn.parallel.DistributedDataParallel(MPF_model, device_ids=[opts.local_rank], output_device=opts.local_rank, find_unused_parameters=True)# optimizer = optim.Adam(MPF_model.parameters(), lr=opts.lr)# scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=50, gamma=0.9)
logger =Noneif opts.local_rank ==0:
log_dir = os.path.join(opts.display_dir,'logger', opts.name)
os.makedirs(log_dir, exist_ok=True)
log_path = os.path.join(log_dir,'log.txt')if os.path.exists(log_path):
os.remove(log_path)
logger = logger_config(log_path=log_path, logging_name='Timer')
logger.info('Parameter Space: ABS: {:.1f}, REL: {:.4f}'.format(count_parameters(MPF_model), count_parameters(MPF_model)/1024/1024))
logger.info(MPF_model)# Train and evaluate multi-task network
multi_task_trainer(train_loader,
train_sampler,
test_loader,
MPF_model,
device,
optimizer,
opts,
logger,
ep,
total_it)defmulti_task_trainer(train_loader, train_sampler, test_loader, multi_task_model, device, optimizer, opt, logger=None, start_ep=0, total_it=0):
total_epoch = opt.n_ep
saver = Saver(opt)## 计算分割损失相关的设计
score_thres =0.7
ignore_idx =255
n_min =8*256*256//8
criteria = OhemCELoss(
thresh=score_thres, n_min=n_min, device=device, ignore_lb=ignore_idx)
binary_class_weight = np.array([1.4548,19.8962])
binary_class_weight = torch.tensor(binary_class_weight).float().to(device)
binary_class_weight = binary_class_weight.unsqueeze(0)
binary_class_weight = binary_class_weight.unsqueeze(2)
binary_class_weight = binary_class_weight.unsqueeze(2)
lb_ignore =[255]if opt.resume:
best_mIou = multi_task_tester(test_loader, multi_task_model, device, opt)else:
best_mIou =0.0if opt.local_rank ==0:print('best mIoU: {:.4f}'.format(best_mIou))
start = glob_st = time()for ep inrange(start_ep, total_epoch):## 每一个epoch 计算一次动态权重
train_sampler.set_epoch(ep)
multi_task_model.train()
seg_metric = SegmentationMetric(opt.class_nb, device=device)## 这里可能会有问题 for it,(img_ir, img_vi, label, bi, bd, mask)inenumerate(train_loader):
total_it +=1
img_ir = img_ir.to(device)
img_vi = img_vi.to(device)
label = label.to(device)
bi = bi.to(device).squeeze(1)
bd = bd.to(device).squeeze(1)
vi_Y, vi_Cb, vi_Cr = RGB2YCrCb(img_vi)
vi_Y = vi_Y.to(device)
vi_Cb = vi_Cb.to(device)
vi_Cr = vi_Cr.to(device)
mask = mask.to(device)
seg_pred, bi_pred, bd_pred, fused_img, re_vi, re_ir = multi_task_model(img_vi, img_ir)# seg_pred = F.softmax(seg_pred, dim=1) # seg_pred = multi_task_model(img_vi, img_ir)
optimizer.zero_grad()
seg_loss = Seg_loss(seg_pred, label, device, criteria)
bd = F.one_hot(bd,num_classes=2)
bd=bd.permute(0,3,1,2).float()
bi = F.one_hot(bi,num_classes=2)
bi= bi.permute(0,3,1,2).float()
bd_loss = F.binary_cross_entropy_with_logits(bd_pred, bd)
bi_loss = F.binary_cross_entropy_with_logits(bi_pred, bi, pos_weight=binary_class_weight)
seg_results = torch.argmax(seg_pred, dim=1, keepdim=True)## print(seg_result.shape())
train_seg_loss =10* seg_loss +5* bi_loss +5* bd_loss
## reconstruction-related loss
fusion_loss, ssim_loss, grad_loss, int_loss = Fusion_loss(img_ir, vi_Y, fused_img, mask)
vi_re_loss, vi_int_loss, vi_grad_loss = Re_loss(re_vi, vi_Y, mask=mask, ir_flag=False)
ir_re_loss, ir_int_loss, ir_grad_loss = Re_loss(re_ir, img_ir, mask=mask, ir_flag=True)
train_loss =1* train_seg_loss +1* fusion_loss +0.5* vi_re_loss +0.5* ir_re_loss
train_loss.backward()
optimizer.step()
seg_metric.addBatch(seg_results, label, lb_ignore)# dist.destroy_process_group()if opt.local_rank ==0:
lr = optimizer.get_lr()
mIoU = np.array(seg_metric.meanIntersectionOverUnion().item())
Acc = np.array(seg_metric.pixelAccuracy().item())
end = time()
training_time, glob_t_intv = end - start, end - glob_st
now_it = total_it+1
eta =int((total_epoch *len(train_loader)- now_it)*(glob_t_intv /(now_it)))
eta =str(datetime.timedelta(seconds=eta))
logger.info('ep: [{}/{}], learning rate: {:.6f}, time consuming: {:.2f}s, segmentation loss: {:.4f}, fusion loss: {:.4f}, vi rec loss: {:.4f}, ir rec loss: {:.4f}'.format(ep+1, total_epoch, lr, training_time, seg_loss.item(), fusion_loss.item(), vi_re_loss.item(), ir_re_loss.item()))
logger.info('ssim loss: [{:.4f}], grad loss: [{:.4f}], int loss: [{:.4f}], segmentation loss: {:.4f}, mIou: {:.4f}, Acc: {:.4f}, Eta: {}\n'.format(ssim_loss.item(), grad_loss.item(), int_loss.item(), seg_loss.item(), mIoU, Acc, eta))
start = time()## save Visualization resultsif(ep +1)% opt.img_save_freq ==0and opt.local_rank ==0:input=[img_ir, img_vi, fused_img, label]
fused_rgb = YCbCr2RGB(fused_img, vi_Cb, vi_Cr)
vi_rgb = YCbCr2RGB(re_vi, vi_Cb, vi_Cr)
output =[re_ir, vi_rgb, fused_rgb, seg_results]
saver.write_img(ep,input, output)## save modelif(ep +1)% opt.model_save_freq ==0and opt.local_rank ==0:
test_mIoU = multi_task_tester(test_loader, multi_task_model, device, opt)
logger.info('test mIoU: {:.4f}, best mIoU:{:.4f}'.format(test_mIoU, best_mIou))if test_mIoU > best_mIou:
best_mIou = test_mIoU
saver.write_model(ep, total_it, multi_task_model, optimizer.optim, best_mIou, device)defmulti_task_tester(test_loader, multi_task_model, device, opts):
multi_task_model.eval()
test_bar= tqdm(test_loader)
seg_metric = SegmentationMetric(opts.class_nb, device=device)
lb_ignore =[255]## define save dirwith torch.no_grad():# operations inside don't track history for it,(img_ir, img_vi, label, img_names)inenumerate(test_bar):
img_ir = img_ir.to(device)
img_vi = img_vi.to(device)
label = label.to(device)
Seg_pred, _, _, fused_img, re_vi, re_ir = multi_task_model(img_vi, img_ir)
seg_result = torch.argmax(Seg_pred, dim=1, keepdim=True)## print(seg_result.shape())
seg_metric.addBatch(seg_result, label, lb_ignore)
mIoU = np.array(seg_metric.meanIntersectionOverUnion().item())return mIoU
if __name__ =='__main__':
main()
options.py
import argparse
classTrainOptions():def__init__(self):
self.parser = argparse.ArgumentParser()# data loader related
self.parser.add_argument('--dataroot',type=str, default='/data/timer/Idea/mtan/dataset/MSRS',help='path of data')
self.parser.add_argument('--phase',type=str, default='train',help='phase for dataloading')
self.parser.add_argument('--batch_size',type=int, default=12,help='batch size')
self.parser.add_argument('--nThreads',type=int, default=16,help='# of threads for data loader')# training related
self.parser.add_argument('--lr', default=1e-3,type=int,help='Initial learning rate for training model')
self.parser.add_argument('--weight', default='dwa',type=str,help='multi-task weighting: equal, uncert, dwa')
self.parser.add_argument('--n_ep',type=int, default=1500,help='number of epochs')# 400 * d_iter
self.parser.add_argument('--n_ep_decay',type=int, default=1000,help='epoch start decay learning rate, set -1 if no decay')# 200 * d_iter
self.parser.add_argument('--resume',type=str, default=None,help='specified the dir of saved models for resume the training')# 不要改该参数,系统会自动分配
self.parser.add_argument('--gpu',type=str, default='cuda',help='device id (i.e. 0 or 0,1 or cpu)')
self.parser.add_argument('--temp', default=2.0,type=float,help='temperature for DWA (must be positive)')# ouptput related
self.parser.add_argument('--name',type=str, default='MPF-Trans-skip_DDP',help='folder name to save outputs')
self.parser.add_argument('--class_nb',type=int, default=9,help='class number for segmentation model')
self.parser.add_argument('--display_dir',type=str, default='/data/timer/Idea/mtan/logs',help='path for saving display results')
self.parser.add_argument('--result_dir',type=str, default='/data/timer/Idea/mtan/results',help='path for saving result images and models')
self.parser.add_argument('--display_freq',type=int, default=10,help='freq (iteration) of display')
self.parser.add_argument('--img_save_freq',type=int, default=10,help='freq (epoch) of saving images')
self.parser.add_argument('--model_save_freq',type=int, default=10,help='freq (epoch) of saving models')# DDP related
self.parser.add_argument('--local_rank',type=int, default=0,help='Specifying the default GPU')defparse(self):
self.opt = self.parser.parse_args()
args =vars(self.opt)print('\n--- load options ---')for name, value insorted(args.items()):print('%s: %s'%(str(name),str(value)))return self.opt
classTestOptions():def__init__(self):
self.parser = argparse.ArgumentParser()# data loader related
self.parser.add_argument('--dataroot',type=str, default='/data/timer/Idea/mtan/dataset/MSRS',help='path of data')
self.parser.add_argument('--phase',type=str, default='test',help='phase for dataloading')
self.parser.add_argument('--batch_size',type=int, default=16,help='batch size')
self.parser.add_argument('--nThreads',type=int, default=16,help='# of threads for data loader')## mode related
self.parser.add_argument('--class_nb',type=int, default=9,help='class number for segmentation model')
self.parser.add_argument('--resume',type=str, default='/data/timer/Idea/mtan/results/MPF-skip/best_model.pth',help='specified the dir of saved models for resume the training')
self.parser.add_argument('--gpu',type=int, default=0,help='GPU id')# results related
self.parser.add_argument('--name',type=str, default='MPF_skip',help='folder name to save outputs')
self.parser.add_argument('--result_dir',type=str, default='/data/timer/Idea/mtan/test',help='path for saving result images and models')defparse(self):
self.opt = self.parser.parse_args()
args =vars(self.opt)print('\n--- load options ---')for name, value insorted(args.items()):print('%s: %s'%(str(name),str(value)))return self.opt
一些主要的操作都在train.py文件里有所涉及,因为是第一次系统的使用DDP,还有很多地方理解的不够透彻,不当之处希望大家指出一起交流。
版权归原作者 Timer-419 所有, 如有侵权,请联系我们删除。