Pytorch-Lightning中的训练器—Trainer
Trainer()
常用参数
参数名称含义默认值接受类型callbacks添加回调函数或回调函数列表None(ModelCheckpoint默认值)Union[List[Callback], Callback, None]enable_checkpointing是否使用callbacksTrueboolgpus使用的gpu数量(int)或gpu节点列表(list或str)None(不使用GPU)Union[int, str, List[int], None]precision指定训练精度32(full precision)Union[int, str]default_root_dir模型保存和日志记录默认根路径None(os.getcwd())Optional[str]logger设置日志记录器(支持多个),若没设置logger的save_dir,则使用default_root_dirTrue(默认日志记录)Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool]max_epochs最多训练轮数(指定为**-1可以设置为无限次**)None(1000)Optional[int]min_epochs最少训练轮数None(1)Optional[int]max_steps最大网络权重更新次数-1(禁用)Optional[int]min_steps最少网络权重更新次数None(禁用)Optional[int]weights_save_path权重保存路径(优先级高于default_root_dir),ModelCheckpoint未定义路径时将使用该路径None(default_root_dir)Optional[str]log_every_n_steps更新n次网络权重后记录一次日志50intauto_scale_batch_size自动搜索最佳batch_size并保存到模型的self.bacth_size中FalseUnion[str, bool]auto_lr_find自动搜索最佳学习率并存储到self.lr或self.learing_rateFalseUnion[str, bool]accumulate_grad_batches每k次batches累计一次梯度NoneUnion[int, Dict[int, int], None]check_val_every_n_epoch每n个train epoch执行一次验证1intnum_sanity_val_steps开始训练前加载n个验证数据进行测试,k=-1时加载所有验证数据2int
额外的解释
- 这里max_steps/min_steps中的step就是指的是优化器的step,优化器每step一次就会更新一次网络权重
- **梯度累加(Gradient Accumulation)**:受限于显存大小,一些训练任务只能使用较小的batch_size,但一般batch-size越大(一定范围内)模型收敛越稳定效果相对越好;梯度累加可以先累加多个batch的梯度再进行一次参数更新,相当于增大了batch_size。
Trainer.fit()
常用参数
参数名称含义默认值modelLightningModule实例****train_dataloaders训练数据加载器Noneval_dataloaders验证数据加载器Noneckpt_pathckpt文件路径(从这里文件恢复训练)Nonedatamodule****LightningDataModule实例None
ckpt_path参数详解(从之前的模型恢复训练)
使用该参数指定一个模型ckpt文件(需要保存整个模型,而不是仅仅保存模型权重),Trainer将从ckpt文件的下一个epoch继续训练。
示范
net = MyNet(...)
trainer = pl.Trainer(...)# 假设模型保存在./ckpt中
trainer.fit(net, train_iter, val_iter, ckpt_path='./ckpt/myresult.ckpt')
使用注意
- 请不要使用Trainer()中的resume_from_checkpoint参数,该参数未来将被丢弃,请使用Trainer.fit()的ckpt_path参数
Trainer.test()
常用参数
参数名称含义默认值modelLightningModule实例None(使用fit()传递的模型)verbose是否打印测试结果Truedataloaders测试数据加载器(可以使用torch.utils.data.DataLoader)Noneckpt_pathckpt文件路径(从这里文件恢复训练)Nonedatamodule****LightningDataModule实例None
版权归原作者 奈何桥边摆地摊 所有, 如有侵权,请联系我们删除。