ParaRNN:解锁大型语言模型的非线性 RNN 并行训练

循环神经网络 (RNN) 为序列建模奠定了基础,但其内在的序列性质限制了并行计算,为扩展造成了根本障碍。这导致了 Transformer 等可并行架构以及最近的状态空间模型 (SSM) 的主导地位。虽然 SSM 通过结构化线性递归实现高效并行化,但这种线性约束限制了它们的表达能力,并妨碍对复杂的非线性序列依赖关系进行建模。为了解决这个问题,我们提出了 ParaRNN,一个打破......

来源:Apple机器学习研究

循环神经网络 (RNN) 为序列建模奠定了基础,但其内在的序列性质限制了并行计算,为扩展造成了根本障碍。这导致了 Transformer 等可并行架构以及最近的状态空间模型 (SSM) 的主导地位。虽然 SSM 通过结构化线性递归实现高效并行化,但这种线性约束限制了它们的表达能力,并妨碍对复杂的非线性序列依赖关系进行建模。为了解决这个问题,我们提出了 ParaRNN,一个打破非线性 RNN 序列并行化障碍的框架。在先前工作的基础上,我们将非线性递推关系序列转换为单个方程组,并使用牛顿迭代与自定义并行约简相结合来并行求解。我们的实现比简单的顺序应用程序实现了高达 665 倍的加速,允许以前所未有的规模训练非线性 RNN。为了展示这一点,我们将 ParaRNN 应用于 LSTM 和 GRU 架构的改编,成功训练了 7B 参数的模型,其复杂度可与类似大小的 Transformer 和 Mamba2 架构相媲美。为了加速高效序列建模的研究,我们发布了 ParaRNN 代码库作为非线性 RNN 自动训练并行化的开源框架,使研究人员和从业者能够大规模探索新的非线性 RNN 模型。