
Multi-Head Attention

In the previous segment, you saw how the attention model calculates the context vector. However, the original paper uses eight such attention heads in parallel, an arrangement it calls multi-head attention. In this segment, you will learn what multi-head attention is and how it differs from a single self-attention head.

Rather than computing a single attention output (one weighted sum of the values), multi-head attention computes several such weighted sums, one per head (hence the name).

In the next video, Ankush will explain the need for multi-head attention.

As explained in the video, there are two major reasons for using multi-head attention:

  • Smaller computations: if a single attention head used all 512 dimensions to calculate the context vector, each matrix multiplication between the query, key and value would be large. With eight heads, each head works on much smaller matrices, and the heads can be computed in parallel.
  • Richer representations: the input encodings are passed to eight different self-attention blocks, each producing an output of dimension S × d_k, where d_k = d_model / h = 512 / 8 = 64. Using multiple heads gives the model a richer understanding of the same sequence, because the context vector is calculated in several different ways.
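With d_model = 512 and h = 8, the per-head dimensions work out as follows. This uses the notation of the original Transformer paper. X denotes the S × d_model matrix of input embeddings, and W_i^Q, W_i^K, W_i^V are the projection matrices of head i. These symbol names are our choice for illustration.

$$d_k = d_v = \frac{d_{model}}{h} = \frac{512}{8} = 64$$

$$Q_i = XW_i^Q,\quad K_i = XW_i^K,\quad V_i = XW_i^V,\qquad X \in \mathbb{R}^{S\times 512},\; W_i^Q, W_i^K, W_i^V \in \mathbb{R}^{512\times 64}$$

$$z_i = \text{softmax}\!\left(\frac{Q_i K_i^{T}}{\sqrt{d_k}}\right)V_i \in \mathbb{R}^{S\times 64}$$

Each head therefore multiplies 64-dimensional queries and keys rather than 512-dimensional ones. The eight heads do not depend on each other, so they can be computed in parallel.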

Each ‘head’ enables the model to capture a different aspect of the input, which improves its performance. Essentially, multi-head attention is just several attention layers running in parallel. Let’s understand the concept of multi-head attention with a real-life example in the next video.

Each head works on its own linear transformation (projection) of the input representation. Using multiple linear transformations, the model can extract different features of the input. But how do we calculate the final context vector from the eight individual context vectors? You will understand this in the next video.

The original paper uses h = 8 attention heads, and therefore eight context vectors. However, h is a hyperparameter that can be customised. As explained in the video, we concatenate the context vectors from all eight heads. This concatenated matrix is then multiplied by a weight matrix W^O, sized so that the output has dimension S × d_model (the same as the dimension of the input embeddings fed to the encoder). The weight matrix also combines the features learnt by the individual heads into a single representation.

$$\text{MultiHead}(Q,K,V)=\text{Concat}\left(z_1,\;z_2,\;z_3,\;z_4,\;z_5,\;z_6,\;z_7,\;z_8\right)W^O,\qquad \text{output shape: }\left(S,\;d_{model}\right)$$
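Tracking the shapes through this formula shows why the output matches the input embeddings. Each z_i has shape S × 64, and there are h = 8 heads:

$$\text{Concat}(z_1,\dots,z_8) \in \mathbb{R}^{S\times (8\cdot 64)} = \mathbb{R}^{S\times 512},\qquad W^O \in \mathbb{R}^{512\times 512}$$

$$\underbrace{(S\times 512)}_{\text{Concat}}\;\underbrace{(512\times 512)}_{W^O}\;\rightarrow\;(S\times 512) = (S\times d_{model})$$

The output of the multi-head block therefore has the same shape as its input, and that is what allows it to be fed to the following sub-layer.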

Let’s quickly revise the entire process:

  • Calculate the Query (Q), Key (K) and Value (V) matrices for each head by multiplying the input embeddings with that head's weight matrices Wq, Wk and Wv.
  • Calculate the context vector (z) from each attention head.
  • Concatenate all the context vectors (z1, z2, ..., z8) and multiply the result by a weight matrix so that the output Z has dimension S × d_model.
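The three steps above can be sketched in plain Python. This is a minimal, non-authoritative sketch, and several parts of it are assumptions. The helper names (`matmul`, `attention`, `multi_head_attention`, `random_matrix`) are hypothetical. Random Gaussian matrices stand in for the learnt weights Wq, Wk, Wv and W^O, and also for the input embeddings. The sizes are kept small (S = 4, d_model = 16, h = 4, so d_k = 4) so the example runs quickly without any libraries. The paper's setting is d_model = 512 and h = 8.

```python
import math
import random

Matrix = list[list[float]]


def matmul(a: Matrix, b: Matrix) -> Matrix:
    """Multiply an (n x m) matrix by an (m x p) matrix."""
    cols = len(b[0])
    return [[sum(row[k] * b[k][j] for k in range(len(row))) for j in range(cols)]
            for row in a]


def transpose(m: Matrix) -> Matrix:
    return [list(col) for col in zip(*m)]


def softmax_rows(m: Matrix) -> Matrix:
    """Row-wise softmax; subtracting the row max keeps exp() numerically stable."""
    out = []
    for row in m:
        peak = max(row)
        exps = [math.exp(x - peak) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention(q: Matrix, k: Matrix, v: Matrix) -> Matrix:
    """Scaled dot-product attention for one head: softmax(Q K^T / sqrt(d_k)) V."""
    d_k = len(q[0])
    scores = matmul(q, transpose(k))
    scaled = [[s / math.sqrt(d_k) for s in row] for row in scores]
    return matmul(softmax_rows(scaled), v)


def random_matrix(rows: int, cols: int, rng: random.Random) -> Matrix:
    """Hypothetical placeholder for learnt weights/embeddings (random Gaussian)."""
    return [[rng.gauss(0.0, 1.0 / math.sqrt(rows)) for _ in range(cols)]
            for _ in range(rows)]


def multi_head_attention(x: Matrix, h: int, rng: random.Random) -> Matrix:
    """x has shape S x d_model; the result also has shape S x d_model."""
    d_model = len(x[0])
    if d_model % h != 0:
        raise ValueError("d_model must be divisible by the number of heads h")
    d_k = d_model // h
    heads = []
    for _ in range(h):
        # Step 1: each head has its own Wq, Wk, Wv of shape d_model x d_k.
        w_q, w_k, w_v = (random_matrix(d_model, d_k, rng) for _ in range(3))
        # Step 2: this head's context vectors z_i, shape S x d_k.
        heads.append(attention(matmul(x, w_q), matmul(x, w_k), matmul(x, w_v)))
    # Step 3: concatenate along the feature axis -> S x (h * d_k) = S x d_model,
    # then project with W^O (shape (h * d_k) x d_model).
    concat = [[value for z in heads for value in z[i]] for i in range(len(x))]
    w_o = random_matrix(h * d_k, d_model, rng)
    return matmul(concat, w_o)


rng = random.Random(0)
S, d_model, h = 4, 16, 4
x = random_matrix(S, d_model, rng)  # stand-in for S input embeddings
z = multi_head_attention(x, h, rng)
print(len(z), len(z[0]))  # 4 16
```

Plain Python lists are used here only to keep the example dependency-free. A practical implementation would use a tensor library, and such implementations usually compute all heads with one batched matrix multiplication rather than a Python loop.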

When we encode the word “it”, the first attention head on its own captures only one feature of the input sequence. But if we use all eight attention heads, the model captures multiple features of the same input sequence.

Now that you have understood how a multi-head attention block works, let’s take a look at the next sub-layer of the encoder block: the pointwise feed-forward network.

Additional Readings

Visualizing Attention in Transformer

Check out this link to visualise different attention heads in the Transformer architecture.
