什么是分组查询注意力 (GQA)?

作者

Dave Bergmann

Senior Staff Writer, AI Models

IBM Think

Cole Stryker

Staff Editor, AI Models

IBM Think

什么是分组查询注意力 (GQA)?

分组查询注意力 (GQA) 是一种提高转换器模型注意力机制效率的方法,通常用于提高大型语言模型 (LLM) 的推理速度。

Ainslie 等人将分组查询注意力设想为多头注意力 (MHA) 的优化,MHA 是 2017 年发表的开创性论文《Attention is All You Need》中提出的创新自注意力算法,它建立了转换器神经网络。更具体地说,它被认为是对多查询注意力 (MQA) 的泛化和更有限制的应用,MQA 是 MHA 的一种早期优化。

虽然标准多头注意力催化了机器学习自然语言处理 (NLP)生成式 AI 的飞跃发展,但它对计算资源和内存带宽的要求极高。随着 LLM 越来越大、越来越复杂,这些内存使用要求成为制约进步的瓶颈,尤其是对于自回归仅解码器 LLM 而言(用于文本生成摘要和其他生成式 AI 任务)。

后续的研究集中在增强或简化多头注意力的技术上。其中一些注意力机制(如 Flash 注意力和 Ring 注意力)改进了用于训练和运行模型的 GPU 处理计算和内存存储的方式。其他系统,例如 GQA 和 MQA,则探索了转换器架构处理词元的方式的改变。

分组查询注意力旨在平衡标准多头注意力和多查询注意力之间的权衡。前者以增加内存带宽开销和降低速度为代价,最大限度地提高了准确性。后者则以牺牲准确性为代价,最大限度地提高速度和效率。

专家为您带来最新的 AI 趋势

获取有关最重要且最有趣的 AI 新闻的精选洞察分析。订阅我们的每周 Think 时事通讯。请参阅 IBM 隐私声明

谢谢!您已订阅。

您的订阅将以英语提供。您会在每份时事通讯中找到一个取消订阅链接。您可以在此处管理您的订阅或取消订阅。有关更多信息,请参阅我们的 IBM 隐私声明

标准多头注意力

要了解分组查询注意力如何优化转换器模型,请务必首先了解多头注意力的一般工作原理。GQA 和 MQA 都只是简单地改进而不是取代 MHA 的核心方法。

LLM 和其他使用转换器架构的模型背后的驱动力是自注意力,这是一个用于理解序列中不同词元之间关系的数学框架。自注意力允许 LLM 不仅通过静态基线定义来解释文本数据,还可以通过其他单词和短语提供的上下文来解释文本数据。

在用于文本生成的自回归 LLM 中,注意力机制通过确定当时哪些先前词元最值得“关注”,来帮助模型预测序列中的下一个词元。它认为最相关的词元信息会被赋予更大的关注权重,而认为不相关的词元信息则会被赋予接近 0 的关注权重。

多头注意力机制通过将注意力层分成多个注意力头,并行多次计算自注意力,为转换器模型生成丰富的上下文信息。

多头注意力图 在《Attention is All You Need》中广为人知的简化多头注意力图

标准多头注意力的工作原理

《Attention is All You Need》的作者使用关系数据库的术语:查询,阐明了其注意力机制。关系数据库旨在简化相关数据的存储和检索:它们为每条数据分配一个唯一标识符(“键”),并且每个都与相应的关联。关系数据库的目标是将每个查询与适当的键相匹配。

对于序列中的每个词元,多头注意力需要创建 3 个向量。

  • 查询向量 Q,表示词元正在“查找”的信息。例如,名词的查询向量可能表示对描述它的形容词的搜索。

  • 键向量 K 表示词元所包含的信息。对齐分数代表每个词元的键向量与其他每个词元的查询向量的相关性,用于计算注意力权重。

  • 值向量 V 表示上下文信息,该信息将由来自其他词元的键向量的注意力加权贡献进行更新。

在注意力机制的调节下,这 3 个向量之间的数学互动,就是模型如何调整其对每个词元的特定上下文理解的过程。
 

生成查询、键和值矢量

要为给定的词元生成这 3 个向量中的每一个,模型都要从该词元的原始向量嵌入开始:这是一种数字编码,其中向量的每个维度都与词元语义的某些抽象元素相对应。这些向量的维数是一个预先确定的超参数。

每个词元的 QKV 向量是通过经第一个注意力层之前的线性层传递原始词元嵌入生成的。该线性层被划分为 3 个独特的模型权重矩阵:WQ、WK 和 WV。其中的具体权重值是通过对海量文本示例数据集进行自监督预训练来学习的。

将词元的原始向量嵌入向量乘以 WQ、WK 和 WV,分别得到对应的查询向量、键向量和值向量。每个向量包含的维数 d 由每个权重矩阵的大小决定。QK 将具有相同的维数 dk

然后,这 3 个向量会被传递到注意力层。

转换器模型注意力机图 转换器的注意力机制简图:输入句子中令牌的原始矢量嵌入分别乘以 W、K 和 V 权重矩阵,得到各自的 W、K 和 V 矢量。
缩放点积注意力和 softmax

在注意力层中,Q、KV 向量用于计算序列中每个位置每个词元之间的对齐分数。然后使用 softmax 函数将这些对齐分数归一化为注意力权重

对于序列中的每个词元 x,对齐分数是通过计算该词元的查询向量 Qx 与每个其他词元的键向量 K点积(即相乘)来求得的。如果两个词元之间有意义的关系体现为它们各自向量之间的相似性,则将它们相乘会产生一个很大的值。如果两个向量没有对齐,则将它们相乘会产生一个小值或负值。大多数转换器模型使用一种称为缩放点积注意力的变体,其中 QK 被缩放(即乘以)1dk来提高训练稳定性。

然后将这些查询-键对齐分数输入到 softmax 函数中。Softmax 将所有输入归一化为 0 到 1 之间的值,使得它们之和为 1。softmax 函数的输出是注意力权重,每个权重代表词元 x 分给其他每个词元的注意力份额(总值为 1)。如果一个词元的注意力权重接近 0,则会被忽略。注意力权重为 1 意味着一个词元将获得 x 的全部注意力,而所有其他词元将被忽略。

最后,将每个词元的值向量乘以其注意力权重。这些来自前一个词元的注意力加权贡献被求和后得出平均值,并添加到词元 x 的原始向量嵌入中。这样,词元 x 的嵌入就得到了更新,以反映序列中与之相关的其他词元所提供的上下文。

然后,更新后的向量嵌入被发送到另一个线性层,该层具有自己的权重矩阵 WZ,其中上下文更新的向量被归一化回一致的维数,然后发送到下一个注意力层。每一个渐进的注意力层都能捕捉到更多细微的背景信息。
 

多注意力头

使用来自其他词元的注意力加权贡献的平均值,而不是单独计算每个注意力加权上下文,在数学上是有效的,但会导致细节丢失。

为了补偿,转换器网络将原始输入的词元嵌入分成 h 个大小均匀的块。其同样将 WQ、WK 和 WV 拆分为 h 个子集,分别称为查询头键头值头。每个查询头、键头和值头都会收到一段原始词元嵌入。每个并行的查询头、键头和值头三元组产生的向量都被输入到相应的注意头中。最终,这 h 个并行回路的输出被连接在一起以更新完整的词元嵌入。

多头注意力的连接 每个注意力头的输出"Z" 被串联在一起。在此示例中,h = 8。

在训练过程中,每个回路都会学习不同的权重,以捕捉语义的不同方面。这反过来又有助于模型处理一个单词的含义受周围其他单词的语境影响的不同方式。

多头注意力模块图 多头注意力模块(h = 8)中所有矩阵乘法的简化图。源自 Jay Alammar 的“图解转换器”。请注意,“+”表示连接,而不是加法。

标准多头注意力的缺点

标准多头注意力的缺点并不在于存在某些关键缺陷,而是缺乏任何优化。MHA 是同类算法中的首创,代表了其注意力计算一般机制中最复杂的执行方式。

MHA 的低效主要源于大量的计算和模型参数。在标准 MHA 中,每个注意力模块中的每个查询头、键头和值头都有自己的权重矩阵。因此,举例来说,一个每个注意力层有 8 个注意力头的模型(远远少于大多数现代 LLM)仅该层的 Q、K 和 V 注意力头就需要 24 个独特的权重矩阵。这就需要在每一层进行大量的中间计算。

这种配置的一个后果是计算成本昂贵。计算 MHA 的要求相对于序列长度呈二次方增长:将输入序列中的词元数量加倍需要将复杂度增加四倍。这对上下文窗口的大小设置了严格的实际限制。

MHA 还会给系统内存带来很大压力。GPU 没有太多的板载内存来存储大量中间计算的输出,这些中间计算在每个后续处理步骤中都必须调用。相反,这些中间结果存储在高带宽内存 (HBM) 中,而该内存并不位于 GPU 芯片本身上。每次必须从内存读取密钥和值时,这都需要少量的延迟。随着转换器模型开始扩展到数十亿个参数,训练和运行推理所需的时间和计算成为模型的性能瓶颈。

要取得进一步的进展,就必须在不降低转换器学习和再现错综复杂的语言模式的能力的前提下,减少计算步骤的数量。MQA 和随后的 GQA 就是在这种情况下推出的。

 

多查询注意力 (MQA) 的工作原理

多查询注意力 (MQA) 是一种计算效率更高的注意力机制,它简化了多头注意力,以减少内存使用和中间计算。MQA 不是为每个注意力头训练唯一的键头和值头,而是在每一层使用单个键头和单个值头。因此,键向量和值向量只计算一次;然后,所有 h 个注意头共享这一组键向量和值向量。

这种简化大大减少了模型必须计算并存储在高带宽内存中的线性投影的数量。根据 2019 年引入 MQA 的论文,MQA 使键值对存储(或 KV 缓存)缩减到 10-100 分之一,将解码器推理速度提高 12 倍。MQA 减少的内存使用量还通过启用更大的批大小显著加快了训练速度。

分组查询注意力图

多查询注意力 (MQA) 的缺点

尽管 MQA 有其优点,但也有一些不可避免的缺点。

  • 性能下降:减少独特的可训练模型参数的数量会降低模型的知识容量和细微差别,这一点不足为奇。与标准 MHA 相比,MQA 的准确性大幅下降,因此不适合某些情况和用例。

  • 必须从头开始训练:用标准 MHA 训练的模型不能简单地适应 MQA,而必须从头开始用 MQA 训练。这意味着 MQA 不能用于优化现有模型,并且在对新模型进行 MQA 试验时会产生相当大的机会成本。

  • 张量并行性中的冗余:在 GPU 上训练转换器模型的主要优点之一是能够并行执行多个复杂的张量运算。执行这些运算的 GPU 集群的每个节点上都必须存在 K 和 V 值,这意味着实际上必须为每个节点复制这两个值。尽管这仍比标准 MHA 更高效,但它却并非计算资源的最佳使用方式。

分组查询注意力 (GQA) 的工作原理

分组查询注意力是多查询注意力的一种更通用、更灵活的形式,它将查询头分为多个,每个组共享一组键和值,而不是所有查询头共享一组键和值。

在 2023 年 5 月《GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints》发布后,许多 LLM 迅速采用了 GQA。例如,Meta 于 2023 年 7 月首次为其 Llama 2 模型采用 GQA,并在 2024 年发布的 Llama 3 模型中保留了 GQA。Mistral AI 在其 2023 年 9 月发布的 Mistral 7B 模型中使用了 GQA。同样,IBM 的 Granite 3.0 模型采用 GQA 进行快速推理。

分组查询注意力与多查询注意力与多头注意力

理论上,GQA 可以看作是标准 MHA 到完全 MQA 之间连续谱的泛化。键值头组数与注意力头数相同的 GQA 相当于标准 MHA;头组数为 1 的 GQA 相当于 MQA。

在实践中,GQA 几乎总是意味着某种中间方法,其中组的数量本身就是一个重要的超参数。

分组查询注意力图

分组查询注意力的优点

分组查询注意力具有几个优势,因此在主要的 LLM 中得到了相对广泛的采用。

  • 高效使用 GPU:GQA 的键值对分发利用了张量并行性,从而减少了因复制冗余值而“浪费”的计算量。

  • 有效的折衷方案:GQA 在解码器推理速度和性能准确性之间提供了理想的权衡,其准确性几乎与 MHA 相当,而速度几乎与 MQA 相当。

  • 降低内存带宽开销:与 MQA 一样,GQA 显著减少了在推理时必须计算、存储和检索的中间计算的数量。

  • 灵活的训练:与 MQA 不同的是,分组查询关注不需要使用这种方法从头开始训练模型。使用标准 MHA 预训练的模型可通过一种称为“向上训练”的微调过程来适应 GQA。
Mixture of Experts | 12 月 12 日,第 85 集

解码 AI:每周新闻摘要

加入我们世界级的专家小组——工程师、研究人员、产品负责人等将为您甄别 AI 领域的真知灼见,带来最新的 AI 资讯与深度解析。

相关解决方案
IBM watsonx.ai

使用面向 AI 构建器的新一代企业级开发平台 IBM watsonx.ai,可以训练、验证、调整和部署生成式 AI、基础模型和机器学习功能。使用一小部分数据,即可在很短的时间内构建 AI 应用程序。

了解 watsonx.ai
人工智能 (AI) 解决方案

借助 IBM 业界领先的人工智能专业知识和解决方案组合,让人工智能在您的业务中发挥作用。

深入了解 AI 解决方案
AI 咨询与服务

通过增加 AI 重塑关键工作流程和运营,最大限度提升体验、实时决策和商业价值。

深入了解人工智能服务
采取后续步骤

一站式访问跨越 AI 开发生命周期的功能。利用用户友好型界面、工作流并访问行业标准 API 和 SDK,生成功能强大的 AI 解决方案。

深入了解 watsonx.ai 预约实时演示