L’attention de requête groupée (GQA) est une méthode permettant d’améliorer l’efficacité du mécanisme d’attention dans les modèles de transformeur. Elle est souvent utilisée pour accélérer l’inférence à partir de grands modèles de langage (LLM).
Ainslie et al. ont conçu l’attention par requêtes groupées comme une optimisation de l’attention multitête (MHA), l’algorithme innovant d’auto-attention introduit dans l’article fondateur de 2017 « Attention is All You Need » qui a posé les bases des réseaux neuronaux transformeurs. Plus précisément, le concept a été proposé comme une généralisation et une application plus restreinte de l’attention multirequête (MQA), une optimisation antérieure du MHA.
Bien que l’attention multitête standard ait provoqué un bond en avant du machine learning, du traitement automatique du langage naturel (NLP) et de l’IA générative, elle est extrêmement gourmande en ressources informatiques et en bande passante mémoire. Au fur et à mesure que les LLM ont gagné en taille et en sophistication, ces besoins de mémoire ont ralenti les progrès, en particulier pour les LLM autorégressifs à décodeur uniquement utilisés dans la génération de texte, la synthèse et d’autres tâches d’IA générative.
Les recherches ultérieures se sont concentrées sur les techniques permettant d’améliorer ou de rationaliser l’attention multitête. Certaines, telles que l’attention éclair et l’attention en anneau, améliorent la façon dont les GPU utilisés pour entraîner et exécuter les modèles gèrent les calculs et le stockage en mémoire. D’autres, comme la GQA et la MQA, modifient la façon dont les architectures transformatrices traitent les tokens.
L’attention par requêtes groupées vise à équilibrer les compromis entre l’attention multitête standard et l’attention multi-requêtes. Le premier optimise la précision au prix d’une augmentation de la bande passante mémoire et d’une diminution de la vitesse. Ce dernier maximise la vitesse et l’efficacité au détriment de la précision.
Pour comprendre comment l’attention de requête groupée optimise les modèles de transformeur, il est important de comprendre comment fonctionne l’attention multitête en général. La GQA et la MQA ne font qu’affiner la méthodologie de base de la MHA ; elles ne la remplacent pas.
Le moteur des LLM et d’autres modèles qui utilisent l’architecture transformatrice est l’auto-attention, un cadre mathématique permettant de comprendre les relations entre chacun des différents tokens d’une séquence. L’auto-attention permet à un LLM d’interpréter des données textuelles non seulement à travers des définitions de base statiques, mais aussi grâce au contexte fourni par d’autres mots et expressions.
Dans les LLM autorégressifs utilisés pour la génération de texte, le mécanisme d’attention aide le modèle à prédire le token suivant dans une séquence en déterminant quels tokens précédents méritent le plus d’être « pris en compte » à ce moment-là. Les informations provenant des tokens qu’il juge les plus pertinents se voient accorder des poids d’attention plus élevés, tandis que les informations provenant de tokens jugés non pertinents reçoivent des poids d’attention proches de 0.
Le mécanisme d’attention multitête qui anime les modèles de transformeur génère des informations contextuelles riches en calculant l’auto-attention plusieurs fois en parallèle par division des couches d’attention en plusieurs têtes d’attention.
Les auteurs de l’article « Attention is All You Need » ont présenté leur mécanisme d’attention en utilisant la terminologie d’une base de données relationnelle : requêtes, clés et valeurs. Les bases de données relationnelles sont conçues pour simplifier le stockage et la récupération des données pertinentes : elles attribuent un identifiant unique (une « clé ») à chaque donnée et chaque clé est associée à une valeur correspondante. L’objectif d’une base de données relationnelle est d’associer chaque requête à la clé appropriée.
Pour chaque token d’une séquence, l’attention multitête nécessite la création de trois vecteurs :
Les interactions mathématiques entre ces trois vecteurs, médiatisées par le mécanisme d’attention, permettent à un modèle d’ajuster sa compréhension contextuelle spécifique de chaque token.
Afin de générer chacun de ces trois vecteurs pour un token donné, le modèle commence par le plongement vectoriel d’origine de ce token : un encodage numérique dans lequel chaque dimension du vecteur correspond à un élément abstrait de la signification sémantique du token. Le nombre de dimensions présentes dans ces vecteurs est un hyperparamètre prédéterminé.
Les vecteurs Q, K et V de chaque token sont générés en faisant passer le plongement du token d’origine à travers une couche linéaire qui précède la première couche d’attention. Cette couche linéaire est divisée en trois matrices uniques de pondérations de modèle : WQ, WK et WV. Les valeurs de pondération spécifiques qui y figurent sont apprises grâce à un pré-apprentissage auto-supervisé sur un jeu de données massif composé d’exemples de texte.
En multipliant le plongement vectoriel d’origine du token par WQ, WK et WV, on obtient respectivement le vecteur de requête, le vecteur de clé et le vecteur de valeur correspondants. Le nombre de dimensions d que contient chaque vecteur est déterminé par la taille de chaque matrice de poids. Q et K ont ainsi le même nombre de dimensions, dk.
Ces trois vecteurs sont ensuite transmis à la couche d’attention.
Dans la couche d’attention, les vecteurs Q, K et V sont utilisés pour calculer un score d’alignement entre chaque token à chaque position d’une séquence. Ces scores d’alignement sont ensuite normalisés en poids d’attention à l’aide d’une fonction softmax.
Pour chaque token x d’une séquence, les scores d’alignement sont calculés en multipliant le produit scalaire du vecteur de requête Qx de ce token par le vecteur de clé K de chacun des autres tokens. Si une relation significative entre deux tokens se traduit par des similitudes entre leurs vecteurs respectifs, la multiplication de ces vecteurs donnera une valeur élevée. Si les deux vecteurs ne sont pas alignés, leur multiplication donnera une valeur faible ou négative. La plupart des modèles de transformeur utilisent une variante appelée attention par produit scalaire dimensionné, dans laquelle QK est dimensionné (c’est-à-dire multiplié) par afin d’améliorer la stabilité de l’entraînement.
Ces scores d’alignement des clés de requête sont ensuite saisis dans une fonction softmaxqui normalise toutes les entrées à une valeur comprise entre 0 et 1, de manière à ce que leur somme soit égale à 1. Les résultats de la fonction softmax sont les poids d’attention, chacun représentant la part (sur 1) de l’attention du token x à accorder à chacun des autres tokens. Si le poids d’attention d’un token est proche de 0, il sera ignoré. Un poids d’attention de 1 signifie qu’un token reçoit toute l’attention de x et que tous les autres sont ignorés.
Enfin, le vecteur de valeur de chaque token est multiplié par son poids d’attention. Une moyenne des contributions pondérées par l’attention de chaque token précédent est calculée et ajoutée au plongement vectoriel d’origine du token x. Avec cela, le plongement du token x est maintenant mis à jour pour refléter le contexte fourni par les autres tokens de la séquence pertinents.
Le plongement vectoriel mis à jour est ensuite envoyé à une autre couche linéaire, avec sa propre matrice de poids WZ, où le vecteur mis à jour en contexte est renormalisé avec un nombre cohérent de dimensions, puis envoyé à la couche d’attention suivante. Chaque couche d’attention progressive capture une plus grande nuance contextuelle.
Il est mathématiquement efficace d’utiliser les moyennes des contributions pondérées par l’attention des autres tokens au lieu de tenir compte de chaque élément de contexte pondéré par l’attention individuellement, mais cela entraîne une perte de détails.
Pour compenser, les réseaux de transformeurs divisent le plongement du token d’entrée d’origine en h morceaux de taille égale. Ils divisent également WQ, WK et WV en h sous-ensembles appelés respectivement têtes de requête, têtes de clé et têtes de valeur. Chaque tête de requête, tête de clé et tête de valeur reçoit un morceau du plongement du token d’origine. Les vecteurs produits par chacun de ces triplets parallèles de têtes d’interrogation, de têtes de clé et de têtes de valeur sont introduits dans une tête d’attention correspondante.Enfin, les résultats de ces h circuits parallèles sont reconcaténés pour mettre à jour le plongement de token complet.
Au cours de l’entraînement, chaque circuit apprend des poids distincts qui capturent un aspect séparé des significations sémantiques. Cela permet au modèle de traiter les différentes façons dont les implications d’un mot peuvent être influencées par le contexte des mots qui l’entourent.
L’inconvénient de l’attention multitête standard n’est pas tant la présence d’un défaut crucial que l’absence d’optimisation. Le MHA a été le premier algorithme de ce type et représente l’exécution la plus complexe de son mécanisme général de calcul de l’attention.
L’inefficacité de l’attention multitête (MHA) provient en grande partie de l’abondance de calculs et de paramètres de modèle. Dans la MHA standard, chaque tête de requête, tête de clé et tête de valeur comprise dans chaque bloc d’attention possède sa propre matrice de poids. Ainsi, par exemple, un modèle comportant huit têtes d’attention dans chaque couche d’attention (ce qui est bien moins que la plupart des LLM modernes) nécessiterait 24 matrices de poids uniques pour les seules têtes Q, K et V de la couche. Cela suppose un nombre considérable de calculs intermédiaires à chaque couche.
L’une des conséquences de cette configuration est qu’elle est coûteuse en matière de calcul. Les exigences de calcul de la MHA varient quadratiquement par rapport à la longueur de la séquence : le doublement du nombre de tokens d’une séquence d’entrée nécessite de quadrupler la complexité. Cela impose des limites pratiques strictes à la taille des fenêtres de contexte.
La MHA met également à rude épreuve la mémoire système. Les GPU n’ont pas beaucoup de mémoire intégrée pour stocker les résultats de l’énorme quantité de calculs intermédiaires qui doivent être rappelés à chaque étape de traitement ultérieure. Ces résultats intermédiaires sont plutôt stockés dans la mémoire à large bande passante (HBM), qui ne se trouve pas sur la puce GPU elle-même. Cela entraîne une courte latence chaque fois que les clés et les valeurs doivent être lues à partir de la mémoire. Avec le dimensionnement des modèles de transformeur à plusieurs milliards de paramètres, le temps et les capacités de calcul nécessaires pour entraîner et exécuter l’inférence sont devenus un obstacle à la performance des modèles.
La poursuite de leur évolution nécessitait des méthodes pour réduire le nombre d’étapes de calcul sans réduire la capacité des transformeurs à apprendre et à reproduire des schémas linguistiques complexes. C’est dans ce contexte que la MQA, et par la suite la GQA, ont vu le jour.
L’attention multirequête (MQA) est un mécanisme d’attention plus efficace en matière de calcul, qui simplifie l’attention multitête pour réduire l’utilisation de la mémoire et les calculs intermédiaires. Au lieu d’entraîner une clé de clé et une tête de valeur uniques pour chaque tête d’attention, la MQA utilise une seule tête de clé et une seule tête de valeur à chaque couche. Par conséquent, les vecteurs de clé et les vecteurs de valeur ne sont calculés qu’une seule fois ; cet ensemble unique de vecteurs de clé et de valeur est ensuite partagé entre les h têtes d’attention.
Cette simplification réduit considérablement le nombre de projections linéaires que le modèle doit calculer et stocker dans la mémoire à large bande passante. Selon l’article de 2019 qui a présenté le MQA, le MQA permet un stockage de paire clé-valeur 10 à 100 fois plus petit (ou cache KV) et une inférence de décodeur 12 fois plus rapide. L’utilisation réduite de la mémoire de MQA accélère également considérablement l’entraînement en augmentant la taille des lots.
Outre ses avantages, la MQA présente quelques inconvénients inévitables.
L’attention de requête groupée est une formulation plus générale et plus flexible de l’attention multirequête, qui partitionne les têtes de requête en plusieurs groupes partageant chacun un ensemble de clés et de valeurs, au lieu de partager un ensemble de clés et de valeurs entre toutes les têtes de requête.
Après la publication de « GQA: Training Generalized Multi-Query Transformer Models from Multi-Head Checkpoints » (GQA : entraîner des modèles de transformation multi-requêtes généralisés à partir de points de contrôle multi-têtes) en mai 2023, de nombreux LLM ont rapidement adopté la GQA. Par exemple, Meta a d’abord adopté la GQA pour ses modèles Llama 2 en juillet 2023, puis l’a conservée dans ses modèles Llama 3 sortis en 2024. Mistral AI a utilisé la GQA dans son modèle Mistral 7B lancé en septembre 2023. De même, les modèles Granite 3.0 d’IBM utilisent la GQA pour une inférence rapide.
En théorie, le GQA peut être considéré comme une généralisation du spectre entre le MHA standard et le MQA complet. Le GQA avec le même nombre de groupes de têtes clé-valeur que de têtes d’attention est l’équivalent du MHA standard ; le GQA avec un groupe de têtes est l’équivalent du MQA.
En pratique, GQA implique presque toujours une approche intermédiaire, dans laquelle le nombre de groupes est en lui-même un hyperparamètre important.
L’attention par requêtes groupées offre plusieurs avantages qui ont conduit à son adoption relativement répandue par les principaux LLM.
Entraînez, validez, réglez et déployez une IA générative, des modèles de fondation et des capacités de machine learning avec IBM watsonx.ai, un studio d’entreprise nouvelle génération pour les générateurs d’IA. Créez des applications d’IA en peu de temps et avec moins de données.
Mettez l’IA au service de votre entreprise en vous appuyant sur l’expertise de pointe d’IBM dans le domaine de l’IA et sur son portefeuille de solutions.
Réinventez les workflows et les opérations critiques en ajoutant l’IA pour optimiser les expériences, la prise de décision et la valeur métier en temps réel.