セクション 3. 章 4
single
チャレンジ:アテンションヒートマップのプロット
メニューを表示するにはスワイプしてください
アテンション重みをヒートマップで可視化することで、トランスフォーマーモデルが文中のどこに注目しているかを解釈しやすくなります。matplotlib を使用して、x軸とy軸の両方に文のトークンを配置したヒートマップを描画します。各セルはトークンのペア間のアテンション重みを示し、行はクエリトークン、列はキートークンに対応します。
まず、文をトークンに分割します:
sentence = "Transformers help models focus on important words."
tokens = sentence.split()
次に、アテンション行列をNumPy配列として定義します。各値は、あるトークンから別のトークンへのアテンション重みを表します:
import numpy as np
attention = np.array([
[0.20, 0.10, 0.05, 0.10, 0.25, 0.10, 0.20],
[0.05, 0.30, 0.10, 0.10, 0.15, 0.20, 0.10],
[0.10, 0.15, 0.35, 0.10, 0.10, 0.10, 0.10],
[0.10, 0.10, 0.10, 0.30, 0.10, 0.15, 0.15],
[0.15, 0.10, 0.10, 0.10, 0.30, 0.10, 0.15],
[0.10, 0.10, 0.15, 0.10, 0.10, 0.35, 0.10],
[0.20, 0.15, 0.10, 0.10, 0.10, 0.10, 0.25],
])
ヒートマップを描画するには、次のコードを使用します:
import matplotlib.pyplot as plt
fig, ax = plt.subplots(figsize=(8, 6))
im = ax.imshow(attention, cmap="viridis")
ax.set_xticks(np.arange(len(tokens)))
ax.set_yticks(np.arange(len(tokens)))
ax.set_xticklabels(tokens, rotation=45, ha="right")
ax.set_yticklabels(tokens)
cbar = plt.colorbar(im, ax=ax)
cbar.set_label("Attention Weight")
ax.set_title("Attention Heatmap")
ax.set_xlabel("Key Tokens")
ax.set_ylabel("Query Tokens")
plt.tight_layout()
plt.show()
カラーマップによって明るい色や暗い色は高いまたは低いアテンション値を示します。ヒートマップを見ることで、各トークンを処理する際にモデルがどの単語に最も注目しているかを確認できます。例えば、focus 行と important 列のセルが明るい場合、モデルは内部表現で focus と important を強く関連付けています。この可視化は、入力文のどの部分が互いに影響し合っているかを理解するのに役立ち、自然言語処理タスクにおけるモデルの挙動の診断や解釈にも有用です。
コードを実行してヒートマップを確認し、最初の可視化プロットを作成してください。
123456789101112131415161718192021222324252627282930313233import numpy as np import matplotlib.pyplot as plt sentence = "Transformers help models focus on important words." tokens = sentence.split() attention = np.array([ [0.20, 0.10, 0.05, 0.10, 0.25, 0.10, 0.20], [0.05, 0.30, 0.10, 0.10, 0.15, 0.20, 0.10], [0.10, 0.15, 0.35, 0.10, 0.10, 0.10, 0.10], [0.10, 0.10, 0.10, 0.30, 0.10, 0.15, 0.15], [0.15, 0.10, 0.10, 0.10, 0.30, 0.10, 0.15], [0.10, 0.10, 0.15, 0.10, 0.10, 0.35, 0.10], [0.20, 0.15, 0.10, 0.10, 0.10, 0.10, 0.25], ]) fig, ax = plt.subplots(figsize=(8, 6)) im = ax.imshow(attention, cmap="viridis") ax.set_xticks(np.arange(len(tokens))) ax.set_yticks(np.arange(len(tokens))) ax.set_xticklabels(tokens, rotation=45, ha="right") ax.set_yticklabels(tokens) cbar = plt.colorbar(im, ax=ax) cbar.set_label("Attention Weight") ax.set_title("Attention Heatmap") ax.set_xlabel("Key Tokens") ax.set_ylabel("Query Tokens") plt.tight_layout() plt.show()
タスク
スワイプしてコーディングを開始
次のアテンション行列を使用して、文「Attention helps models understand context.」のアテンションヒートマップをプロットしてください。
attention = [
[0.25, 0.15, 0.20, 0.20, 0.20],
[0.10, 0.40, 0.15, 0.20, 0.15],
[0.15, 0.10, 0.35, 0.20, 0.20],
[0.20, 0.15, 0.20, 0.25, 0.20],
[0.15, 0.20, 0.20, 0.20, 0.25],
]
matplotlibを使用してヒートマップを作成- 両方の軸に文のトークンをラベル付け
- カラーバーに「Attention Weight」とラベル付け
- プロットのタイトルを「Attention Heatmap」とする
解答
すべて明確でしたか?
フィードバックありがとうございます!
セクション 3. 章 4
single
AIに質問する
AIに質問する
何でも質問するか、提案された質問の1つを試してチャットを始めてください