nan报错

2022年01月16日 阅读数:1
这篇文章主要向大家介绍nan报错,主要内容包括基础应用、实用技巧、原理机制等方面,希望对大家有所帮助。


前言


  • 系统环境:torch 1.9.1
  • 遇到​​nan​​ 的主要排查顺序:(1)先检查主要代码有没有写错?(2)接着看会不会再log,exp()的时候出现计算问题?(3)梯度大小设置是否合理?

1. 场景1

前几天在写一段代码的时候,遇到了​​nan​​​错误,可是一直么有找到问题所在。个人代码主要是将一个 ​​batch​​​ 中的entity对应的 ​​embedding​​ 送到自定义的2层BertModel中。git

缘由


  • 送入到BertModel中仅传入了 hidden_states,没有传入attention_mask。可是后来搞定这个bug以后仍是有nan值的村子。
  • 因而按行debug,发现问题在于送入到BertEncoder中的 hidden_states 中有nan的存在。
    咱们能够用下面这行代码检查tensor中是否有nan 值。
    ​torch.any(torch.isnan(entity_bank))​​ 这个 ​​entity_bank​​ 即是一个待检测的向量。因而对应解决这个nan值的生成便可。
    但奇怪的是,使用下图中注释的一行就会有nan 的问题,可是用第二行代码则没有这个问题。
    nan报错_pytorch

2. 场景2

最近在写一个自定义的损失函数,形式以下:ide

nan报错_参数说明_02

其中的L1 是正样本的损失,L2是负样本的损失。函数

缘由

训练过程当中再次出现nan,后来发现缘由是:上式中的 ​​log(*)​​ 中的值可能存在0,我是先对0求了log,而后过滤掉了inf值,可是事实证实这么作是不行的。最好的方式是:学习

eps=1e-7
log(x+eps)

从而避免nan值的产生。spa

3. 场景3

nan最近真的是粘上我了 ┭┮﹏┭┮debug

短短几天,碰到了3次nan报错,并且每次都不同。今天(2021/12/18)此次碰到的状况是模型训练到了必定的epoch时就开始出现​​nan​​。以下所示:code

nan报错_参数说明_03

由于已经到了第50epoch,因此我就怀疑是梯度的问题了。若是是脏数据的问题,那么应该在一个epoch以内就会报错。从网上查阅资料,感受像是由于梯度爆炸致使的nan错,因而我调整了一个动态学习率,就没有再报这个错了。blog

若是是本身手写的loss,那么必定要保证最后获得的loss是一个均值,要么是除以​​batch_size​​,要么是除以​​样本条数​​,由于这个系数在求导的时候,就会被带到其中做为梯度的系数,这样就会把梯度降下来。get

4. 场景4

由于代码bug致使的nan。it

我在使用 R-Drop 的时候,须要使用KL散度求两个 ​​predict logit​​ 间的距离,pytorch中的 kl_div 函数有个参数说明以下:

nan报错_pytorch_04

这个参数:表示传入的 ​​target​​ 是不是在​​log​​空间下(便是否由log处理过),推荐传入固定的分布结果,(如​​softmax​​)避免由log致使的数值问题。

5. 场景5

还有一种常见的nan报错是 除零致使的错。以下所述:

背景:同一个entity能够对应多个mention,可是我想取这个entity的表示,就能够经过取各个mention表示的平均来获取。最后除以这个mention的个数便可。

因而我这么写了代码:

cur_entity_emb = torch.mm(cur_entity2mention,x) / torch.sum(cur_entity2mention,dim=1).unsqueeze_(1).expand(-1,808) # 拿到整个entity的表示

获得的这个 ​​cur_entity_emb​​​ 中存在nan,那么罪魁祸首应该就是后面的这个 ​​torch.sum()​​存在问题,果不其然,查了一下的确是这个的问题。