PyTorch NaN 是无声杀手 - 因此我构建了一个 3ms Hook 来在精确层捕获它们

NaN 不会破坏你的训练——它们会悄悄地破坏它。在 ResNet 训练运行中因无声故障而损失了几个小时后,我构建了一个轻量级检测器,可以精确定位出现问题的确切层和批次。使用前向钩子和梯度检查,它可以以最小的开销尽早发现问题,而不会减慢模型的速度。PyTorch NaNs 是无声杀手——所以我构建了一个 3ms 的钩子来在精确层捕获它们,该文章首先出现在《走向数据科学》上。

来源:走向数据科学
  • NaN 并不起源于它们出现的地方——它们默默地跨层传播
  • torch.autograd.set_detect_anomaly 太慢并且经常误导真正的调试
  • 基于前向钩子的检测器可以在确切的层捕获 NaN,并捕获它们第一次出现的批次
  • 每次前向传递的开销约为 3–4 毫秒,远低于异常检测(尤其是在 GPU 上)
  • 在大多数情况下,梯度爆炸是真正的根本原因——尽早发现它可以完全防止 NaN
  • 系统记录结构化事件(层、批次、统计)以进行精确调试
  • 专为生产而设计:线程安全、内存有限且可扩展
  • 批量为 47,000。我已经在自定义医学成像数据集上训练了六个小时的 ResNet 变体。损失完全收敛——1.4、1.1、0.87、0.73——然后就什么也没有了。不是错误。不是崩溃。只是楠。

    我添加了 torch.autograd.set_detect_anomaly(True) 并重新启动。训练速度慢得像爬行一样——仅在 CPU 上每批的时间就长了大约 7-10 倍——三个小时后,我终于得到了一个堆栈跟踪,指向一个坦率地说看起来不错的层。真正的罪魁祸首是学习率调度程序与上游两层的自定义标准化层交互不良。set_detect_anomaly 向我指出了症状,而不是根源。

    那次调试花费了我一天的大部分时间。所以我做了一些更好的东西。

    NaN 不会破坏你的模型——它们会悄悄地破坏它。当您注意到时,您已经在调试错误的层了。

    完整代码:https://github.com/Emmimal/pytorch-nan- detector/

    set_detect_anomaly 的问题

    PyTorch 附带 torch.autograd.set_detect_anomaly(True),这是调试 NaN 问题的标准建议。它的工作原理是保留完整的计算图并在向后传递期间检查异常情况。这很强大,但它带来了严重的成本,使其除了快速本地健全性检查之外不适合任何其他用途。

    我的基准测试,在小型 CPU MLP (64→256→256→10) 上运行,测得:

    方法:前钩

    实施

    组件 3:有限内存