Notice: This page requires JavaScript to function properly.
Please enable JavaScript in your browser settings or update your browser.
Lernen Herausforderung: Plotten von Attention-Heatmaps | Anwendung von Transformern auf NLP-Aufgaben
Transformer für Natural Language Processing
Abschnitt 3. Kapitel 4
single

single

bookHerausforderung: Plotten von Attention-Heatmaps

Swipe um das Menü anzuzeigen

Die Visualisierung von Attention-Gewichten mit einer Heatmap unterstützt die Interpretation, wie ein Transformer-Modell seine Aufmerksamkeit über einen Satz verteilt. Mit matplotlib wird eine Heatmap erstellt, bei der sowohl die x- als auch die y-Achse die Token des Satzes darstellen. Jede Zelle in der Heatmap zeigt das Attention-Gewicht zwischen einem Token-Paar: Die Zeile entspricht dem Query-Token, die Spalte dem Key-Token.

Beginne damit, den Satz in Token zu unterteilen:

sentence = "Transformers help models focus on important words."
tokens = sentence.split()

Als Nächstes wird die Attention-Matrix als NumPy-Array definiert. Jeder Wert repräsentiert das Attention-Gewicht von einem Token zum anderen:

import numpy as np
attention = np.array([
    [0.20, 0.10, 0.05, 0.10, 0.25, 0.10, 0.20],
    [0.05, 0.30, 0.10, 0.10, 0.15, 0.20, 0.10],
    [0.10, 0.15, 0.35, 0.10, 0.10, 0.10, 0.10],
    [0.10, 0.10, 0.10, 0.30, 0.10, 0.15, 0.15],
    [0.15, 0.10, 0.10, 0.10, 0.30, 0.10, 0.15],
    [0.10, 0.10, 0.15, 0.10, 0.10, 0.35, 0.10],
    [0.20, 0.15, 0.10, 0.10, 0.10, 0.10, 0.25],
])

Zum Plotten der Heatmap wird folgender Code verwendet:

import matplotlib.pyplot as plt

fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(attention, cmap="viridis")

ax.set_xticks(np.arange(len(tokens)))
ax.set_yticks(np.arange(len(tokens)))
ax.set_xticklabels(tokens, rotation=45, ha="right")
ax.set_yticklabels(tokens)

cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Attention Weight")

ax.set_title("Attention Heatmap")
ax.set_xlabel("Key Tokens")
ax.set_ylabel("Query Tokens")

plt.tight_layout()
plt.show()

Helle oder dunkle Farben zeigen je nach Farbschema höhere oder niedrigere Attention-Werte an. Beim Betrachten der Heatmap wird ersichtlich, auf welche Wörter das Modell bei der Verarbeitung jedes Tokens besonders achtet. Ist beispielsweise die Zelle in der Zeile focus und der Spalte important hell, besteht eine starke Verbindung zwischen focus und important in der internen Repräsentation des Modells. Diese Visualisierung erleichtert das Verständnis, welche Teile des Eingabesatzes sich gegenseitig beeinflussen, und ist hilfreich zur Diagnose oder Interpretation des Modellverhaltens bei Aufgaben der natürlichen Sprachverarbeitung.

Führe nun den Code aus, um die resultierende Heatmap zu sehen, und erstelle anschließend dein erstes Visualisierungsdiagramm.

123456789101112131415161718192021222324252627282930313233
import numpy as np import matplotlib.pyplot as plt sentence = "Transformers help models focus on important words." tokens = sentence.split() attention = np.array([ [0.20, 0.10, 0.05, 0.10, 0.25, 0.10, 0.20], [0.05, 0.30, 0.10, 0.10, 0.15, 0.20, 0.10], [0.10, 0.15, 0.35, 0.10, 0.10, 0.10, 0.10], [0.10, 0.10, 0.10, 0.30, 0.10, 0.15, 0.15], [0.15, 0.10, 0.10, 0.10, 0.30, 0.10, 0.15], [0.10, 0.10, 0.15, 0.10, 0.10, 0.35, 0.10], [0.20, 0.15, 0.10, 0.10, 0.10, 0.10, 0.25], ]) fig, ax = plt.subplots(figsize=(8, 6)) im = ax.imshow(attention, cmap="viridis") ax.set_xticks(np.arange(len(tokens))) ax.set_yticks(np.arange(len(tokens))) ax.set_xticklabels(tokens, rotation=45, ha="right") ax.set_yticklabels(tokens) cbar = plt.colorbar(im, ax=ax) cbar.set_label("Attention Weight") ax.set_title("Attention Heatmap") ax.set_xlabel("Key Tokens") ax.set_ylabel("Query Tokens") plt.tight_layout() plt.show()
copy
Aufgabe

Wischen, um mit dem Codieren zu beginnen

Erstellung eines Attention-Heatmaps für den Satz "Attention helps models understand context." unter Verwendung der folgenden Attention-Matrix:

attention = [
    [0.25, 0.15, 0.20, 0.20, 0.20],
    [0.10, 0.40, 0.15, 0.20, 0.15],
    [0.15, 0.10, 0.35, 0.20, 0.20],
    [0.20, 0.15, 0.20, 0.25, 0.20],
    [0.15, 0.20, 0.20, 0.20, 0.25],
]
  • Verwendung von matplotlib zur Erstellung eines Heatmaps;
  • Beschriftung beider Achsen mit den Token des Satzes;
  • Hinzufügen einer Farbskala mit der Beschriftung "Attention Weight;"
  • Titel des Plots: "Attention Heatmap."

Lösung

Switch to desktopWechseln Sie zum Desktop, um in der realen Welt zu übenFahren Sie dort fort, wo Sie sind, indem Sie eine der folgenden Optionen verwenden
War alles klar?

Wie können wir es verbessern?

Danke für Ihr Feedback!

Abschnitt 3. Kapitel 4
single

single

Fragen Sie AI

expand

Fragen Sie AI

ChatGPT

Fragen Sie alles oder probieren Sie eine der vorgeschlagenen Fragen, um unser Gespräch zu beginnen

some-alt