Flash注意2:减少GPU内存并加速Transformers

将公共 MCP 服务器部署为 API 端点,并使用函数调用将其工具集成到 LLM 工作流程中。

来源:Clarifai博客 | 实际应用中的人工智能

Flash注意2:减少GPU内存并加速Transformers

简介

Transformer 革命现已深入到长上下文时代。GPT-4(32 k 令牌)、MosaicML 的 MPT(65 k)和 Claude(100 k)等模型可以处理整个章节或代码库。然而,随着上下文的增长,注意力机制成为瓶颈:计算相似度矩阵 S = Q·K^T 和概率矩阵 P = softmax(S) 会产生 N×N 的数据结构。这些矩阵必须在 GPU 的微型片上 SRAM 与其较大但速度较慢的高带宽内存 (HBM) 之间移动,从而消耗带宽并限制吞吐量。在计算 FLOP 持续攀升的世界中,真正的限制已成为内存。

FlashAttention 于 2022 年推出,通过平铺计算以避免存储完整的 S 或 P 矩阵来解决这个问题,提供 2–4 倍的加速和高达 10–20 倍的内存节省。FlashAttention-2 (FA2) 更进一步:它减少了昂贵的非 matmul 操作,跨序列长度并行化,并分区工作以最小化共享内存流量。基准测试显示 FA2 的速度大约是其前身的两倍,比标准注意力实现快九倍,在 NVIDIA A100 GPU 上达到 225 TFLOPs/s。本指南解释了 FA2 的工作原理、何时使用它、如何将其集成到您的堆栈中以及它的局限性。

快速摘要

  • FA2 解决了内存限制问题。Attention 的 N² 内存占用会导致 GPU 停顿;平铺和内核融合将其降低到线性内存成本。
  • 主要创新:更少的非 matmul FLOP、沿序列长度的额外并行性以及跨扭曲切片查询矩阵。
  • 采用:支持 Ampere/Ada/Hopper GPU 和 FP16/BF16 数据类型。通过 pip 安装并在 PyTorch 或 Hugging Face 中翻转标志以启用。
  • 谁受益:任何训练或服务长上下文模型(8 k-16 k token)或使用大头部尺寸的人;节省的成本是巨大的。
  • 变形金刚中的内存瓶颈

    为什么内存(而不是计算)很重要

    专家见解

    快速摘要

    注意事项