Swin Transformer

一种基于移位窗口的层次化Transformer

基本架构

基本架构
基本架构

首先将输入的RGB图片划分为不重叠的patch(图中大小为443),之后将每一个patch扁平化位移位向量,这样初始化图像转化成了H4×W4×48\frac{H}{4}\times \frac{W}{4}\times 48,之后将patch进一个MLP将其维度变为CC

之后进行patch的合并,每次将2*2单位内的patch进行合并,之后进一个MLP保证patch的大小只扩张两倍
这种架构可以方便的替换其他一些网络中的backbone

而Swin Transformer块是将传统的Multi-head Self Attention替换成了(Shift-)Window MSA,每一Stage中的层数可能不同,由超参数决定

(S)W-MSA

用于解决常规Attention平方复杂度的问题,将图片划分为不重叠的窗口,窗口的大小为M×MM\times M,而其中MM为常数,这样可以将计算的复杂度从O(N2)O(N^{2})降到O(N)O(N),但是这样的话失去了窗口之间的联系,可能导致信息的丢失,于是引入窗口的移位,即SW-MSA

基础版移动窗口
基础版移动窗口

第l个模块是W-MSA的话,第l+1个就会是SW-MSA,会将窗口向左上循环移动(M2,M2)(\lfloor\frac{M}{2}\rfloor, \lfloor\frac{M}{2}\rfloor)位,确保了跨窗口信息可以被保留下来

但是如果按照Fig2中的移动方式,窗口数会从(hM,wM)(\lceil\frac{h}{M}\rceil, \lceil\frac{w}{M}\rceil)变成(hM+1,wM+1)(\lceil\frac{h}{M}\rceil + 1, \lceil\frac{w}{M}\rceil + 1),因此采用下面这种方式:

移动窗口
移动窗口

这种方式不会增加窗口数量,但是每一个窗口可能是由许多子窗口拼接而成的,因此引入masked限制attention的计算范围,也即给每一个窗口一个index,在attention的过程中只留下相同id窗口的计算结果,忽略其他值,这一步需要依靠根据当前窗口的排布来确定mask

在计算attention的过程中,引入相对偏移矩阵BB

相对位置偏移
相对位置偏移

其中QKVRM2×dQKV\in\mathbb{R}^{M^{2}\times d},因此BRM2×M2B\in\mathbb{R}^{M^{2}\times M^{2}},其是从B^R(2M1)2\hat{B}\in\mathbb{R}^{(2M - 1)^{2}}中取值得到的

BB的每一行分别代表了以窗口的第ii个patch为原点时,其他patch针对原点的相对偏移量,并通过偏移数组B^\hat{B}来确定偏移量,详细计算过程参考博客

原版参数设置

相对位置偏移
相对位置偏移