single
Uitdaging: Het Plotten van Aandacht-Heatmaps
Veeg om het menu te tonen
Het visualiseren van aandachtgewichten met een heatmap helpt bij het interpreteren van hoe een transformermodel zijn focus over een zin verdeelt. Met matplotlib wordt een heatmap geplot waarbij zowel de x-as als de y-as de tokens uit de zin weergeven. Elke cel in de heatmap toont het aandachtgewicht tussen een paar tokens: de rij komt overeen met het query-token en de kolom met het key-token.
Begin met het splitsen van de zin in tokens:
sentence = "Transformers help models focus on important words."
tokens = sentence.split()
Definieer vervolgens de attentiematrix als een NumPy-array. Elke waarde vertegenwoordigt het aandachtgewicht van het ene token naar het andere:
import numpy as np
attention = np.array([
[0.20, 0.10, 0.05, 0.10, 0.25, 0.10, 0.20],
[0.05, 0.30, 0.10, 0.10, 0.15, 0.20, 0.10],
[0.10, 0.15, 0.35, 0.10, 0.10, 0.10, 0.10],
[0.10, 0.10, 0.10, 0.30, 0.10, 0.15, 0.15],
[0.15, 0.10, 0.10, 0.10, 0.30, 0.10, 0.15],
[0.10, 0.10, 0.15, 0.10, 0.10, 0.35, 0.10],
[0.20, 0.15, 0.10, 0.10, 0.10, 0.10, 0.25],
])
Gebruik de volgende code om de heatmap te plotten:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(attention, cmap="viridis")
ax.set_xticks(np.arange(len(tokens)))
ax.set_yticks(np.arange(len(tokens)))
ax.set_xticklabels(tokens, rotation=45, ha="right")
ax.set_yticklabels(tokens)
cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Attention Weight")
ax.set_title("Attention Heatmap")
ax.set_xlabel("Key Tokens")
ax.set_ylabel("Query Tokens")
plt.tight_layout()
plt.show()
Fellere of donkerdere kleuren geven hogere of lagere aandachtwaarden aan, afhankelijk van de colormap. In de heatmap is te zien op welke woorden het model het meest let bij het verwerken van elk token. Als bijvoorbeeld de cel op rij focus en kolom important fel is, legt het model een sterke verbinding tussen focus en important in zijn interne representatie. Deze visualisatie helpt te begrijpen welke delen van de invoerzin elkaar beïnvloeden en is nuttig voor het diagnosticeren of interpreteren van modelgedrag bij natural language processing-taken.
Voer nu de code uit om de resulterende heatmap te bekijken en schrijf vervolgens je eerste visualisatieplot.
123456789101112131415161718192021222324252627282930313233import numpy as np import matplotlib.pyplot as plt sentence = "Transformers help models focus on important words." tokens = sentence.split() attention = np.array([ [0.20, 0.10, 0.05, 0.10, 0.25, 0.10, 0.20], [0.05, 0.30, 0.10, 0.10, 0.15, 0.20, 0.10], [0.10, 0.15, 0.35, 0.10, 0.10, 0.10, 0.10], [0.10, 0.10, 0.10, 0.30, 0.10, 0.15, 0.15], [0.15, 0.10, 0.10, 0.10, 0.30, 0.10, 0.15], [0.10, 0.10, 0.15, 0.10, 0.10, 0.35, 0.10], [0.20, 0.15, 0.10, 0.10, 0.10, 0.10, 0.25], ]) fig, ax = plt.subplots(figsize=(8, 6)) im = ax.imshow(attention, cmap="viridis") ax.set_xticks(np.arange(len(tokens))) ax.set_yticks(np.arange(len(tokens))) ax.set_xticklabels(tokens, rotation=45, ha="right") ax.set_yticklabels(tokens) cbar = plt.colorbar(im, ax=ax) cbar.set_label("Attention Weight") ax.set_title("Attention Heatmap") ax.set_xlabel("Key Tokens") ax.set_ylabel("Query Tokens") plt.tight_layout() plt.show()
Veeg om te beginnen met coderen
Plot een aandacht-heatmap voor de zin "Attention helps models understand context." met behulp van de volgende aandachtmatrix:
attention = [
[0.25, 0.15, 0.20, 0.20, 0.20],
[0.10, 0.40, 0.15, 0.20, 0.15],
[0.15, 0.10, 0.35, 0.20, 0.20],
[0.20, 0.15, 0.20, 0.25, 0.20],
[0.15, 0.20, 0.20, 0.20, 0.25],
]
- Gebruik
matplotlibom een heatmap te maken; - Label beide assen met de tokens van de zin;
- Voeg een kleurenschaal toe met het label "Attention Weight;"
- Geef de plot de titel "Attention Heatmap."
Oplossing
Bedankt voor je feedback!
single
Vraag AI
Vraag AI
Vraag wat u wilt of probeer een van de voorgestelde vragen om onze chat te starten.