Attention Rollout Visualization

October 24, 2023
This article studies attention rollout and attention maps in the Vision Transformer (ViT) paper.

This article explores Section D.8 of the ViT paper, where we can find very interesting figures of attention maps. Based on the paper Quantifying Attention Flow in Transformers [1] and ViT-pytorch [2], we will work out how those figures are computed.

TL;DR

  • Attention weights (self-attention) quantify information transfer between tokens within the same layer.
  • Attention rollout extends this concept across layers: how much information is transferred from a token in layer $l_1$ to another token in layer $l_2$?
  • The heatmap you have seen in the ViT paper corresponds to the information transferred from the embedded-patch layer to the class token in the final layer: how much is each patch used for class inference?

1. Recalling Self-Attention

Information Exchange in the Same Layer

In transformers, the tokens (patches in ViT) of an input sequence exchange information in the self-attention layer $l$ (refer to this colab for an interactive demonstration).

For the $i$th token $z_{l-1}^i$ coming from layer $l-1$ (notation borrowed from the ViT paper), an attention head in layer $l$ first extracts a query $q_l^i$, key $k_l^i$, and value $v_l^i$ through linear projections with learnable weights. The output for the $i$th token is then a combination of the value vectors $(v_l^0, v_l^1, \dots, v_l^N)$ with attention weights proportional to $w_l^{ij} = q_l^i (k_l^j)^T$ (I will omit softmax and scaling for brevity). This attention weight $w_l^{ij}$ can be seen as the information exchange between tokens $i$ and $j$ in one head of attention layer $l$.
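As a concrete illustration, here is a minimal PyTorch sketch of how these attention weights arise for a single head. The tensor names and sizes are made up for illustration, and softmax and scaling are included here even though they are omitted in the formulas above.

```python
import torch

n_tokens, d = 5, 8                    # class token + 4 patches, head dimension d
z = torch.randn(n_tokens, d)          # tokens z_{l-1}^0 ... z_{l-1}^N from layer l-1

W_q = torch.randn(d, d)               # learnable projection weights (random stand-ins)
W_k = torch.randn(d, d)
W_v = torch.randn(d, d)

q, k, v = z @ W_q, z @ W_k, z @ W_v               # q_l^i, k_l^i, v_l^i for every token
w = torch.softmax(q @ k.T / d ** 0.5, dim=-1)     # w_l^{ij}: (n_tokens, n_tokens) attention weights
out = w @ v                                       # token i mixes the values v_l^j weighted by w_l^{ij}
```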

No Reliable Accumulation of Information across Layers

According to the authors of [1], $w_l^{ij}$ does not accumulate the information passed through the layers up to $l-1$, even though it accounts for the exchange within layer $l$ itself.

There are multiple ways to build intuition for this. First of all, attention weights $w_l^{ij}$ tend to get smaller in the upper layers (larger $l$), as the figure below (borrowed from [1]) shows.

The y-axis denotes the layer index $l$ and the x-axis denotes tokens. Color denotes the attention weight $w_l^{0j}$, where $0$ is the class token.
For layers 4, 5, and 6, the magnitudes of the attention weights are nearly uniform across tokens: they carry no useful information.

Another, more statistical, approach is to investigate the correlation between 1) the information score of a token and 2) its attention weight in the final layer $L$, $w_L^{0i}$.

  1. Information score of a token: the drop in inference performance when token $i$ is blanked out across layers (e.g., the drop in classification performance will be large if we blank out mouth patches for a cats-vs-dogs network).
  2. $w_L^{0i}$: the information from the $i$th word or image token flowing into the classification token. This token is fed into the classification head (MLP) in ViT.

Assume that the attention weight $w_L^{0i}$ had faithfully accumulated the information of token $i$ from the first layer up to the final layer. Then there should be a high correlation between the information score and the magnitude of $w_L^{0i}$. But Table 1 of [1] shows that this is not the case.

In conclusion, although $w_l^{ij}$ can capture the information exchange between tokens $i$ and $j$ within the same layer $l$, it is not a reliable measure of the information accumulated from token $j$ up to layer $l$ and transferred to token $i$.

2. Attention Rollout

Information across a Single Layer

Then, what is a good metric for the accumulated information transfer from a token $i$ in layer $l_1$ to another token $j$ in layer $l_2$? (Good means a high correlation with the information score from the previous section.) Paper [1] presents two measures: attention rollout and attention flow. In this article, I will examine the former only.

First of all, let us collect the attention weights $w_l^{ij}$ at layer $l$ into a matrix $W_{att}(l) = (w_l^{ij})$ with $0 \le i, j \le N$ ($i$ = row, $j$ = col). The authors propose that the information exchange through a single layer $l$, composed of self-attention and the residual connection, can be encoded as $A(l, l-1) = W_{att}(l) + I$.
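Here is a minimal PyTorch sketch of this single-layer matrix; the attention matrix is a random stand-in, and the row re-normalization is optional here (see the Disclaimer below).

```python
import torch

n_tokens = 5                                      # class token + 4 patches
w_att = torch.softmax(torch.randn(n_tokens, n_tokens), dim=-1)  # stand-in for W_att(l)

a = w_att + torch.eye(n_tokens)                   # A(l, l-1) = W_att(l) + I (residual path)
a = a / a.sum(dim=-1, keepdim=True)               # optional: keep each row summing to 1
```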

Information across Multiple Layers

What happens when we proceed through an additional layer $l+1$? The $(r, c)$ element of $A(l, l-1)$ represents the information transfer from token $c$ to token $r$ in layer $l$. Across the two layers $l$ and $l+1$, any pair of tokens is connected by multiple paths. For example, in the figure below, $z_{l-1}^1$ can transfer information to $z_{l+1}^2$ through every token $z_l^*$.

Thus, we can associate the $(2, 1)$ element of the product $A(l+1, l-1) = A(l+1, l) A(l, l-1)$ with the information transfer from $z_{l-1}^1$ to $z_{l+1}^2$. The same chaining applies to any two tokens in different layers. This is called attention rollout. According to [1], this measure correlates well with the information score, as shown in Table 1.
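The full rollout is just this chain of matrix products. Below is a minimal sketch; `attn_per_layer` is an assumed list of per-layer attention matrices (already fused over heads), ordered from layer 1 to layer $L$.

```python
import torch

def attention_rollout(attn_per_layer: list[torch.Tensor]) -> torch.Tensor:
    """attn_per_layer: W_att(1), ..., W_att(L), each of shape (n_tokens, n_tokens)."""
    n = attn_per_layer[0].size(-1)
    rollout = torch.eye(n)
    for w_att in attn_per_layer:                  # layers 1, 2, ..., L
        a = w_att + torch.eye(n)                  # A(l, l-1) = W_att(l) + I
        a = a / a.sum(dim=-1, keepdim=True)       # re-normalize rows (see Disclaimer below)
        rollout = a @ rollout                     # A(l, 0) = A(l, l-1) A(l-1, 0)
    return rollout                                # A(L, 0)
```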

Disclaimer

Kindly note that I skipped the row normalization of the matrices for clarity of explanation. Also, I did not consider the presence of multiple heads. How heads are handled is implementation-dependent, as this repo shows: we can take the mean over heads, or keep only the heads with strong attention.
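For illustration, here is a minimal sketch of two common head-fusion choices, assuming `w_att_heads` stacks the per-head attention matrices of one layer; the function name is made up.

```python
import torch

def fuse_heads(w_att_heads: torch.Tensor, mode: str = "mean") -> torch.Tensor:
    """w_att_heads: per-head attention of one layer, shape (n_heads, n_tokens, n_tokens)."""
    if mode == "mean":
        return w_att_heads.mean(dim=0)            # average over heads
    if mode == "max":
        return w_att_heads.max(dim=0).values      # keep the strongest attention per entry
    raise ValueError(f"unknown mode: {mode}")
```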

3. Interpreting Heatmap Visualization

Now, we are one step away from Section D.8 of the ViT paper! Our interest is to measure how much information from an input patch is used for classification. As we discussed, this can be derived from the attention rollout $A(L, 0) = A(L, L-1) A(L-1, L-2) \cdots A(1, 0)$. Then, what is the meaning of the $0$th row of $A(L, 0)$? It is the information accumulated into the classification token from each input patch.

As this figure illustrates, the heatmap value of the first patch can be considered a measure of how much of its information was used to infer the class of the image. Of course, a patch is larger than a single pixel, so the heatmap is very coarse. In practice, most implementations interpolate the patch-level heatmap up to the original image size, as sketched below.
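Here is a minimal sketch of that last step, assuming a 14x14 patch grid and a 224x224 input image (the ViT-B/16 defaults); the function name is made up for illustration.

```python
import torch
import torch.nn.functional as F

def rollout_heatmap(rollout: torch.Tensor, grid: int = 14, img_size: int = 224) -> torch.Tensor:
    """rollout: A(L, 0) of shape (1 + grid*grid, 1 + grid*grid), class token at index 0."""
    cls_to_patches = rollout[0, 1:]                       # 0th row: information into the class token
    heatmap = cls_to_patches.reshape(1, 1, grid, grid)    # coarse patch-level map
    heatmap = F.interpolate(heatmap, size=(img_size, img_size),
                            mode="bilinear", align_corners=False)
    return (heatmap / heatmap.max())[0, 0]                # (img_size, img_size), scaled to [0, 1]
```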

My explanation so far boils down to this notebook written by jeonsworld.

Please leave a comment if you find something weird or incorrect. I hope this article helped you understand the heatmaps that frequently appear in the context of transformers!