Rearrange of einops: Input Patch Flattening in Vision Transformer

September 7, 2023

The very first operation of the Vision Transformer (ViT) is to split an image into patches, as the figure in the paper shows:

If we look into the PyTorch implementation of ViT (vit-pytorch), the seemingly complicated reshaping is handled in a single line:

vit-pytorch/vit_pytorch/vit.py
  self.to_patch_embedding = nn.Sequential(
      Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_height, p2 = patch_width),
      nn.LayerNorm(patch_dim),
      nn.Linear(patch_dim, dim),
      nn.LayerNorm(dim),
  )

This short article focuses on understanding this line, step by step.
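
Before dissecting the pattern, a quick shape-level sanity check helps. The sketch below builds the same Sequential with illustrative values I picked for this example (224x224 RGB images, 16x16 patches, embedding dimension dim=128); these numbers are not taken from the repository.

import torch
from torch import nn
from einops.layers.torch import Rearrange

patch_height = patch_width = 16              # illustrative patch size
patch_dim = 3 * patch_height * patch_width   # 768 values per flattened patch (p1 * p2 * c)
dim = 128                                    # illustrative embedding dimension

to_patch_embedding = nn.Sequential(
    Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
    nn.LayerNorm(patch_dim),
    nn.Linear(patch_dim, dim),
    nn.LayerNorm(dim),
)

images = torch.rand(2, 3, 224, 224)          # a batch of 2 RGB images
tokens = to_patch_embedding(images)
print(tokens.shape)                          # torch.Size([2, 196, 128]) -- 14 x 14 = 196 patches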

Einops

The PyTorch implementation of ViT depends on einops, a library specially designed for tensor operations (including reshape and transpose). It is maintained by a mathematician, Alex Rogozhnikov, and there is a dedicated paper (ICLR 2022). You can read the official documentation for the basics.

Patch Flattening using Einops

So what does the line 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)' mean? Let me explain this by visualizing the access flow of the tensor. We can understand the line by breaking it down into four steps:

  1. Decomposition: b c (h p1) (w p2) -> b c h p1 w p2
  2. Transpose: b c h p1 w p2 -> b c h w p1 p2
  3. Transpose: b c h w p1 p2 -> b h w p1 p2 c
  4. Composition: b h w p1 p2 c -> b (h w) (p1 p2 c)

First of all, our original shape is b (batch) x c (color) x H (image height) x W (image width).

We can trace the access flow as the arrows show: indexing first into the batch, we then pick one of the three color channels, and what remains is a single-channel image. Up to this point, the 4D tensor is easy to understand intuitively.
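
To make the walkthrough concrete, the snippets in the following steps use a small tensor whose sizes I chose only so that every axis is easy to inspect: a batch of 2 RGB images of height 4 and width 6, to be split into 2x2 patches (p1 = 2, p2 = 2).

import torch

x = torch.rand(2, 3, 4, 6)   # b, c, H, W (illustrative sizes)
print(x.shape)               # torch.Size([2, 3, 4, 6])
print(x[0][0].shape)         # torch.Size([4, 6]) -- one mono-channel image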

Step 1. Decomposition along Height and Width of Image

Decomposing Height: b c (h p1) W -> b c h p1 W

For a patch height p1, we can split the image height H into h x p1. Since h comes before p1 in the pattern, h becomes the major (slower-varying) axis, as the figure shows:

Indexing up to the h axis gives us an interpretable image of size p1 x W (i.e., tensor[0][0][0] is a mono-channel image).
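
Here is a minimal sketch of just this step on the example tensor above; in the pattern I spell the still-undivided image width as width to avoid clashing with the patch-grid axis w introduced later.

import torch
from einops import rearrange

x = torch.rand(2, 3, 4, 6)    # b, c, H, W (illustrative sizes)
step1 = rearrange(x, 'b c (h p1) width -> b c h p1 width', p1=2)
print(step1.shape)            # torch.Size([2, 3, 2, 2, 6])
print(step1[0][0][0].shape)   # torch.Size([2, 6]) -- a p1 x W strip of the mono-channel image
assert torch.equal(step1[0][0][0], x[0][0][:2])   # it is the top two rows of that channel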

Decomposing Width: b c h p1 (w p2) -> b c h p1 w p2

Now we split the width W into w x p2. The new axis w indexes the column position of a patch. See the figure below:

The last three dimensions (p1, w, p2) no longer represent an image, so indexing up to the h axis can no longer be visualized as one.
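
The same check on the example tensor, after also decomposing the width, confirms this:

import torch
from einops import rearrange

x = torch.rand(2, 3, 4, 6)    # b, c, H, W (illustrative sizes)
step2 = rearrange(x, 'b c (h p1) (w p2) -> b c h p1 w p2', p1=2, p2=2)
print(step2.shape)            # torch.Size([2, 3, 2, 2, 3, 2])
print(step2[0][0][0].shape)   # torch.Size([2, 3, 2]) -- p1 x w x p2, not a contiguous image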

Step 2. Transpose

Switch Patch Height and Patch Column Index: b c h p1 w p2 -> b c h w p1 p2

We switch the two axes corresponding to the patch height p1 and the patch column index w (not the image width W).

Indexing up to the w axis now gives a 2D image of the patch size, in this figure a 2x2 image patch.
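
On the example tensor, indexing down to a single (h, w) position now yields a p1 x p2 patch that matches a contiguous corner of the original channel:

import torch
from einops import rearrange

x = torch.rand(2, 3, 4, 6)   # b, c, H, W (illustrative sizes)
step3 = rearrange(x, 'b c (h p1) (w p2) -> b c h w p1 p2', p1=2, p2=2)
print(step3[0][0][0][0].shape)   # torch.Size([2, 2]) -- one 2x2 patch of a single channel
assert torch.equal(step3[0][0][0][0], x[0][0][:2, :2])   # the top-left corner of that channel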

Color as the Last Axis: b c h w p1 p2 -> b h w p1 p2 c

Now the color axis is moved to the last position.

Once we index up to w (the patch column index), the remaining three-dimensional tensor represents a colored image of the patch.
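
On the example tensor:

import torch
from einops import rearrange

x = torch.rand(2, 3, 4, 6)   # b, c, H, W (illustrative sizes)
step4 = rearrange(x, 'b c (h p1) (w p2) -> b h w p1 p2 c', p1=2, p2=2)
print(step4.shape)           # torch.Size([2, 2, 3, 2, 2, 3])
print(step4[0][0][0].shape)  # torch.Size([2, 2, 3]) -- a p1 x p2 x c colored patch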

Step 3. Composition for Flattening: b h w p1 p2 c -> b (h w) (p1 p2 c)

The last step flattens the patch grid (h x w) into a single patch-index axis and each colored patch (p1 x p2 x c) into a single feature vector.
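
Putting it all together on the example tensor, the single pattern produces one row per patch, and each row is the flattened colored patch (the top-left-patch check at the end is just for illustration):

import torch
from einops import rearrange

x = torch.rand(2, 3, 4, 6)   # b, c, H, W (illustrative sizes)
patches = rearrange(x, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=2, p2=2)
print(patches.shape)   # torch.Size([2, 6, 12]) -- 2 x 3 = 6 patches, 2 * 2 * 3 = 12 values each

# the first token is the top-left colored patch, flattened in (p1, p2, c) order
assert torch.equal(patches[0][0], rearrange(x[0, :, :2, :2], 'c p1 p2 -> (p1 p2 c)'))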

Further Note

You can play with rearrange yourself in a similar way:

import torch
from einops import rearrange
 
random_tensor = torch.rand(4, 4)
print(random_tensor)
# split the 4x4 matrix into 2x2 blocks: block-row h, block-column w, in-block indices p1, p2
print(rearrange(random_tensor, '(h p1) (w p2) -> h w p1 p2', p1=2, p2=2))
# additionally flatten each block into a vector of 4 values
print(rearrange(random_tensor, '(h p1) (w p2) -> h w (p1 p2)', p1=2, p2=2))
# also flatten the block grid: one row per block
print(rearrange(random_tensor, '(h p1) (w p2) -> (h w) (p1 p2)', p1=2, p2=2))
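
In the last printout, each row is one 2x2 block of the original matrix flattened row-major: the top-left block becomes the first row, the top-right block the second, and so on. This is exactly what the ViT pattern does per image, with the extra (p1 p2 c) grouping pulling the color values of each patch into the same row.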