1. 反直觉的bfloat16
torch支持单精度浮点数bfloat16。这种数据类型在使用的时候需要格外小心,因为它很可能会表现出一系列的“反人类直觉”特性。
什么是bfloat16
BF16是brain float的简称(来源于google brain)。不同于普通的单精度浮点数FP16(i.e., torch.float16),BF16是介于FP16和FP32之间的一种浮点数格式。BF16的指数位比FP16多,跟FP32一样,不过小数位比较少。即,BF16尝试采用牺牲精度的方法,来换取更大的数值空间(Dynamic Range)。
bfloat16带来的问题
虽然有实验和研究都已经表明,BF16的这种“牺牲精度”并不会损害性能 (甚至某些情况下能带来性能提升),并且速度更快,内存消耗更少 (和FP16一样)。但是在实际使用过程中,它往往还是会带来很多的模型训练时的负面影响,如:
- 混合精度训练时,loss出现
NAN
或INF
- 巨大的数值间隙,令人费解
比方说,下面这张图,计算constractive loss的时候,需要把positive loss (一个较大的数值)和negative loss (一个较小的数值)相加,就会由于BF16的小数精度表达能力过弱,而导致negative loss根本不起效果:
再比如下面这张图,也非常容易让人confused:
所以,在某些特定场景和需求下,我们可以选择不用BF16,而使用传统的FP16。比方说,计算constractive loss,对于小数数值精度有一定要求。
但很遗憾,目前很多与训练语言模型,比方说T5,都是使用BF16进行预训练的…
计算机体系结构的知识忘的太多了,现在已经有点记不得浮点数的相关概念了。。。
解决方案
参考自:float16 vs bfloat16 numerical properties comparison
主要有两种策略来应对bfloat16的精度问题:
- 添加
assert
: 比较保守的一种策略。如果实在不确定自己的代码当前bfloat16计算是不是已经出现了浮点数溢出问题,可以在代码里面添加断言,来检测模型训练过程中 (尤其是计算loss的时候),tensor中是否已经出现了overflow现象。例如:
assertnot torch.isinf(loss).any().item()andnot torch.isnan(loss).any().item()
- 禁用pytorch subnormal numbers 如果需要使用bfloat生成高精度的tensor,用下面这段代码把torch的denormal操作关闭:
_ = torch.set_flush_denormal(False)
示范:
- 强行更改
dtype
最暴力的一种解决精度问题的策略,自然就是强行把bfloat的矩阵类型改为双精度。适用于需要较高精度支持的场景:
2. 参考:
- Mixed precision for bfloat16-pretrained models
- 深度学习与bfloat16(BF16)
- What Every User Should Know About Mixed Precision Training in PyTorch
- float16 vs bfloat16 numerical properties comparison
版权归原作者 Reza. 所有, 如有侵权,请联系我们删除。