How to Implement Multi-Head Attention
```python
import numpy as np

# Toy input: 3 tokens, embedding size 6
x = np.array([
    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
    [0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
    [0.2, 0.1, 0.4, 0.3, 0.6, 0.5],
])

# Multi-head attention parameters
num_heads = 2
embed_size = x.shape[1]
head_dim = embed_size // num_heads

# Random weights for queries, keys, and values for each head
np.random.seed(42)
W_q = np.random.randn(num_heads, embed_size, head_dim)
W_k = np.random.randn(num_heads, embed_size, head_dim)
W_v = np.random.randn(num_heads, embed_size, head_dim)

attention_outputs = []

for head in range(num_heads):
    # Linear projections
    Q = x @ W_q[head]
    K = x @ W_k[head]
    V = x @ W_v[head]

    # Scaled dot-product attention
    scores = Q @ K.T / np.sqrt(head_dim)
    attention_weights = np.exp(scores) / np.sum(np.exp(scores), axis=-1, keepdims=True)
    head_output = attention_weights @ V
    attention_outputs.append(head_output)

# Concatenate outputs from all heads
multi_head_output = np.concatenate(attention_outputs, axis=-1)

print("Output from head 0:\n", attention_outputs[0])
print("Output from head 1:\n", attention_outputs[1])
print("Concatenated multi-head output:\n", multi_head_output)
```
Multi-head attention applies the same self-attention mechanism independently in several heads. Each head has its own weight matrices for queries, keys, and values (W_q, W_k, and W_v), which project the input into a lower-dimensional subspace. In this example the input x contains three tokens, each with an embedding size of 6. With two heads, head_dim is 6 // 2 = 3, so each head projects the full 6-dimensional input down to 3 dimensions; the input itself is not sliced between the heads.
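To make the per-head projection concrete, here is a minimal sketch that only inspects shapes. It assumes the same dimensions as the example (3 tokens, embedding size 6, 2 heads); the random input and the seed are arbitrary choices for illustration and are not taken from the lesson.

```python
import numpy as np

# Same toy dimensions as the example: 3 tokens, embedding size 6, 2 heads
num_tokens, embed_size, num_heads = 3, 6, 2
head_dim = embed_size // num_heads  # 6 // 2 = 3

rng = np.random.default_rng(0)  # arbitrary seed, for illustration only
x = rng.standard_normal((num_tokens, embed_size))
W_q = rng.standard_normal((num_heads, embed_size, head_dim))

# Each head sees the full 6-dimensional input and projects it down to 3 dimensions
Q_per_head = [x @ W_q[h] for h in range(num_heads)]

for h, Q in enumerate(Q_per_head):
    print(f"head {h}: W_q shape {W_q[h].shape}, Q shape {Q.shape}")
```

Each head's weight matrix has shape (6, 3) and each resulting Q has shape (3, 3): one 3-dimensional query per token. The same holds for K and V.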
For each head, the code repeats the standard self-attention steps: it computes the query (Q), key (K), and value (V) matrices, calculates attention scores as the dot product of Q and K.T divided by the square root of head_dim, and applies a softmax over each row to obtain attention weights. These weights are then used to combine the value vectors, producing a head output of shape (3, 3).
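The per-head steps can also be packaged as a small function. The sketch below is one possible way to do that, not the lesson's own code: the helper names softmax and single_head_attention are hypothetical, the weights use an arbitrary seed, and the softmax subtracts the row maximum before exponentiating, a common stability trick that leaves the result mathematically unchanged.

```python
import numpy as np

def softmax(scores):
    # Subtracting the row maximum avoids overflow in np.exp without changing the result
    shifted = scores - np.max(scores, axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / np.sum(exp, axis=-1, keepdims=True)

def single_head_attention(x, W_q, W_k, W_v):
    # Linear projections into the head's subspace
    Q, K, V = x @ W_q, x @ W_k, x @ W_v
    # Scaled dot-product scores: one row per query token, one column per key token
    scores = Q @ K.T / np.sqrt(Q.shape[-1])
    weights = softmax(scores)
    # Each output row is a weighted average of the value vectors
    return weights @ V, weights

x = np.array([
    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
    [0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
    [0.2, 0.1, 0.4, 0.3, 0.6, 0.5],
])

rng = np.random.default_rng(0)  # arbitrary seed, for illustration only
W_q, W_k, W_v = (rng.standard_normal((6, 3)) for _ in range(3))

output, weights = single_head_attention(x, W_q, W_k, W_v)
print("Row sums of attention weights:", weights.sum(axis=-1))
print("Head output shape:", output.shape)
```

Every row of the attention weights sums to 1 (up to floating-point error), which is what lets each output row be read as a weighted average of the value vectors.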
There is no new mechanism here: multi-head attention is simply the self-attention process repeated once per head. The heads are independent of one another, so they can be computed in parallel, although this example loops over them for clarity. After all heads are processed, their outputs are concatenated along the last dimension, giving a (3, 6) matrix with the same shape as the input. Because each head has its own weights, each one can attend to different relationships in the input. The printed results show the two individual head outputs and the concatenated representation for each token.
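Because the heads do not depend on each other, the loop can be replaced by batched array operations. The following is a hedged sketch of that idea using np.einsum and batched matrix multiplication; it reuses the lesson's input, seed, and weight shapes, and checks the result against a per-head loop. The variable names loop_output and heads are introduced here for illustration.

```python
import numpy as np

x = np.array([
    [0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
    [0.6, 0.5, 0.4, 0.3, 0.2, 0.1],
    [0.2, 0.1, 0.4, 0.3, 0.6, 0.5],
])
num_heads = 2
embed_size = x.shape[1]
head_dim = embed_size // num_heads

np.random.seed(42)
W_q = np.random.randn(num_heads, embed_size, head_dim)
W_k = np.random.randn(num_heads, embed_size, head_dim)
W_v = np.random.randn(num_heads, embed_size, head_dim)

# Project for all heads at once:
# (tokens, embed) x (heads, embed, head_dim) -> (heads, tokens, head_dim)
Q = np.einsum("te,hed->htd", x, W_q)
K = np.einsum("te,hed->htd", x, W_k)
V = np.einsum("te,hed->htd", x, W_v)

# Batched scaled dot-product attention: scores has shape (heads, tokens, tokens)
scores = Q @ K.transpose(0, 2, 1) / np.sqrt(head_dim)
scores = scores - scores.max(axis=-1, keepdims=True)  # stability shift, result unchanged
weights = np.exp(scores)
weights = weights / weights.sum(axis=-1, keepdims=True)
heads = weights @ V  # (heads, tokens, head_dim)

# Concatenate heads per token: (heads, tokens, head_dim) -> (tokens, heads * head_dim)
multi_head_output = heads.transpose(1, 0, 2).reshape(x.shape[0], num_heads * head_dim)

# Reference: the same computation with an explicit loop over heads
loop_heads = []
for h in range(num_heads):
    q, k, v = x @ W_q[h], x @ W_k[h], x @ W_v[h]
    s = q @ k.T / np.sqrt(head_dim)
    w = np.exp(s) / np.sum(np.exp(s), axis=-1, keepdims=True)
    loop_heads.append(w @ v)
loop_output = np.concatenate(loop_heads, axis=-1)

print("Shape:", multi_head_output.shape)
print("Matches loop version:", np.allclose(multi_head_output, loop_output))
```

Full Transformer layers typically pass the concatenated output through one more learned linear projection before it is used further; the example in this lesson stops at the concatenation step.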