Notice: This page requires JavaScript to function properly.
Please enable JavaScript in your browser settings or update your browser.
Apprendre Challenge: Tracer des Cartes Thermiques d'Attention | Application des Transformers aux Tâches NLP
Transformers pour le Traitement du Langage Naturel
Section 3. Chapitre 4
single

single

bookChallenge: Tracer des Cartes Thermiques d'Attention

Glissez pour afficher le menu

La visualisation des poids d'attention à l'aide d'une carte thermique (heatmap) permet d'interpréter comment un modèle transformeur répartit son attention sur une phrase. Vous utilisez matplotlib pour tracer une carte thermique où l'axe des abscisses (x) et l'axe des ordonnées (y) représentent les jetons de la phrase. Chaque cellule de la carte thermique indique le poids d'attention entre une paire de jetons : la ligne correspond au jeton de requête (query), et la colonne correspond au jeton clé (key).

Commencez par découper votre phrase en jetons :

sentence = "Transformers help models focus on important words."
tokens = sentence.split()

Ensuite, définissez votre matrice d'attention comme un tableau NumPy. Chaque valeur représente le poids d'attention d'un jeton vers un autre :

import numpy as np
attention = np.array([
    [0.20, 0.10, 0.05, 0.10, 0.25, 0.10, 0.20],
    [0.05, 0.30, 0.10, 0.10, 0.15, 0.20, 0.10],
    [0.10, 0.15, 0.35, 0.10, 0.10, 0.10, 0.10],
    [0.10, 0.10, 0.10, 0.30, 0.10, 0.15, 0.15],
    [0.15, 0.10, 0.10, 0.10, 0.30, 0.10, 0.15],
    [0.10, 0.10, 0.15, 0.10, 0.10, 0.35, 0.10],
    [0.20, 0.15, 0.10, 0.10, 0.10, 0.10, 0.25],
])

Pour tracer la carte thermique, utilisez le code suivant :

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(attention, cmap="viridis")

ax.set_xticks(np.arange(len(tokens)))
ax.set_yticks(np.arange(len(tokens)))
ax.set_xticklabels(tokens, rotation=45, ha="right")
ax.set_yticklabels(tokens)

cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Attention Weight")

ax.set_title("Attention Heatmap")
ax.set_xlabel("Key Tokens")
ax.set_ylabel("Query Tokens")

plt.tight_layout()
plt.show()

Des couleurs plus claires ou plus foncées indiquent des valeurs d'attention plus élevées ou plus faibles, selon la palette de couleurs utilisée. En observant la carte thermique, vous pouvez voir sur quels mots le modèle porte le plus d'attention lors du traitement de chaque jeton. Par exemple, si la cellule à la ligne focus et à la colonne important est claire, le modèle établit une forte connexion entre focus et important dans sa représentation interne. Cette visualisation permet de comprendre quelles parties de la phrase d'entrée s'influencent mutuellement et s'avère utile pour diagnostiquer ou interpréter le comportement du modèle dans les tâches de traitement du langage naturel.

Exécutez maintenant le code pour visualiser la carte thermique obtenue, puis rédigez votre premier graphique de visualisation.

123456789101112131415161718192021222324252627282930313233
import numpy as np import matplotlib.pyplot as plt sentence = "Transformers help models focus on important words." tokens = sentence.split() attention = np.array([ [0.20, 0.10, 0.05, 0.10, 0.25, 0.10, 0.20], [0.05, 0.30, 0.10, 0.10, 0.15, 0.20, 0.10], [0.10, 0.15, 0.35, 0.10, 0.10, 0.10, 0.10], [0.10, 0.10, 0.10, 0.30, 0.10, 0.15, 0.15], [0.15, 0.10, 0.10, 0.10, 0.30, 0.10, 0.15], [0.10, 0.10, 0.15, 0.10, 0.10, 0.35, 0.10], [0.20, 0.15, 0.10, 0.10, 0.10, 0.10, 0.25], ]) fig, ax = plt.subplots(figsize=(8, 6)) im = ax.imshow(attention, cmap="viridis") ax.set_xticks(np.arange(len(tokens))) ax.set_yticks(np.arange(len(tokens))) ax.set_xticklabels(tokens, rotation=45, ha="right") ax.set_yticklabels(tokens) cbar = plt.colorbar(im, ax=ax) cbar.set_label("Attention Weight") ax.set_title("Attention Heatmap") ax.set_xlabel("Key Tokens") ax.set_ylabel("Query Tokens") plt.tight_layout() plt.show()
copy
Tâche

Glissez pour commencer à coder

Tracer une carte thermique d'attention pour la phrase "Attention helps models understand context." en utilisant la matrice d'attention suivante :

attention = [
    [0.25, 0.15, 0.20, 0.20, 0.20],
    [0.10, 0.40, 0.15, 0.20, 0.15],
    [0.15, 0.10, 0.35, 0.20, 0.20],
    [0.20, 0.15, 0.20, 0.25, 0.20],
    [0.15, 0.20, 0.20, 0.20, 0.25],
]
  • Utiliser matplotlib pour créer une carte thermique ;
  • Étiqueter les deux axes avec les jetons de la phrase ;
  • Ajouter une barre de couleur intitulée "Attention Weight" ;
  • Titrer le graphique "Attention Heatmap".

Solution

Switch to desktopPassez à un bureau pour une pratique réelleContinuez d'où vous êtes en utilisant l'une des options ci-dessous
Tout était clair ?

Comment pouvons-nous l'améliorer ?

Merci pour vos commentaires !

Section 3. Chapitre 4
single

single

Demandez à l'IA

expand

Demandez à l'IA

ChatGPT

Posez n'importe quelle question ou essayez l'une des questions suggérées pour commencer notre discussion

some-alt