single
Utfordring: Plotting av Oppmerksomhetsvarmekart
Sveip for å vise menyen
Visualisering av oppmerksomhetsvekter med et varmekart hjelper deg å tolke hvordan en transformer-modell fordeler fokuset sitt over en setning. Du bruker matplotlib for å tegne et varmekart der både x-aksen og y-aksen representerer tokenene fra setningen. Hver celle i varmekartet viser oppmerksomhetsvekten mellom et token-par: raden tilsvarer spørringstokenet, og kolonnen tilsvarer nøkkeltokenet.
Start med å splitte setningen din i token:
sentence = "Transformers help models focus on important words."
tokens = sentence.split()
Deretter definerer du oppmerksomhetsmatrisen som et NumPy-array. Hver verdi representerer oppmerksomhetsvekten fra ett token til et annet:
import numpy as np
attention = np.array([
[0.20, 0.10, 0.05, 0.10, 0.25, 0.10, 0.20],
[0.05, 0.30, 0.10, 0.10, 0.15, 0.20, 0.10],
[0.10, 0.15, 0.35, 0.10, 0.10, 0.10, 0.10],
[0.10, 0.10, 0.10, 0.30, 0.10, 0.15, 0.15],
[0.15, 0.10, 0.10, 0.10, 0.30, 0.10, 0.15],
[0.10, 0.10, 0.15, 0.10, 0.10, 0.35, 0.10],
[0.20, 0.15, 0.10, 0.10, 0.10, 0.10, 0.25],
])
For å tegne varmekartet, bruk følgende kode:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(attention, cmap="viridis")
ax.set_xticks(np.arange(len(tokens)))
ax.set_yticks(np.arange(len(tokens)))
ax.set_xticklabels(tokens, rotation=45, ha="right")
ax.set_yticklabels(tokens)
cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Attention Weight")
ax.set_title("Attention Heatmap")
ax.set_xlabel("Key Tokens")
ax.set_ylabel("Query Tokens")
plt.tight_layout()
plt.show()
Sterkere eller svakere farger indikerer høyere eller lavere oppmerksomhetsverdier, avhengig av fargekartet. Når du ser på varmekartet, kan du se hvilke ord modellen legger mest vekt på når den behandler hvert token. For eksempel, hvis cellen på rad focus og kolonne important er lys, kobler modellen sterkt focus til important i sin interne representasjon. Denne visualiseringen hjelper deg å forstå hvilke deler av inngangssetningen som påvirker hverandre, og er nyttig for å diagnostisere eller tolke modellatferd i oppgaver innen naturlig språkprosessering.
Kjør nå koden for å se det resulterende varmekartet, og skriv deretter din første visualiseringsgraf.
123456789101112131415161718192021222324252627282930313233import numpy as np import matplotlib.pyplot as plt sentence = "Transformers help models focus on important words." tokens = sentence.split() attention = np.array([ [0.20, 0.10, 0.05, 0.10, 0.25, 0.10, 0.20], [0.05, 0.30, 0.10, 0.10, 0.15, 0.20, 0.10], [0.10, 0.15, 0.35, 0.10, 0.10, 0.10, 0.10], [0.10, 0.10, 0.10, 0.30, 0.10, 0.15, 0.15], [0.15, 0.10, 0.10, 0.10, 0.30, 0.10, 0.15], [0.10, 0.10, 0.15, 0.10, 0.10, 0.35, 0.10], [0.20, 0.15, 0.10, 0.10, 0.10, 0.10, 0.25], ]) fig, ax = plt.subplots(figsize=(8, 6)) im = ax.imshow(attention, cmap="viridis") ax.set_xticks(np.arange(len(tokens))) ax.set_yticks(np.arange(len(tokens))) ax.set_xticklabels(tokens, rotation=45, ha="right") ax.set_yticklabels(tokens) cbar = plt.colorbar(im, ax=ax) cbar.set_label("Attention Weight") ax.set_title("Attention Heatmap") ax.set_xlabel("Key Tokens") ax.set_ylabel("Query Tokens") plt.tight_layout() plt.show()
Sveip for å begynne å kode
Plott et oppmerksomhetsvarmekart for setningen "Attention helps models understand context." ved å bruke følgende oppmerksomhetsmatrise:
attention = [
[0.25, 0.15, 0.20, 0.20, 0.20],
[0.10, 0.40, 0.15, 0.20, 0.15],
[0.15, 0.10, 0.35, 0.20, 0.20],
[0.20, 0.15, 0.20, 0.25, 0.20],
[0.15, 0.20, 0.20, 0.20, 0.25],
]
- Bruk
matplotlibfor å lage et varmekart; - Merk begge aksene med setningens tokener;
- Legg til en fargeskala merket "Attention Weight;"
- Tittel på diagrammet: "Attention Heatmap."
Løsning
Takk for tilbakemeldingene dine!
single
Spør AI
Spør AI
Spør om hva du vil, eller prøv ett av de foreslåtte spørsmålene for å starte chatten vår