补充1:
局部窗口内的自注意力(W-MSA):
滑动窗口机制(Shifted Window Attention):
计算量的比较:
滑动窗口的好处:
补充2:
输入图像大小(224x224):
Patch Embedding 和 Stride=16:
16x16
的 non-overlapping patches,然后将每个 patch 展平并映射到一个高维的特征空间。Stride=16
的卷积操作,将图像的空间分辨率从 224x224 减少到 14x14。特征图大小(14x14):
Stride=16
的操作后,原始图像被划分为 14x14 个 patch,每个 patch 被视为一个 token。在 Vision Transformer 中,这 14x14 个 token 会形成一个 196 维的 token 序列。Swin Transformer 在一些细节上和 ViT 有所不同:
h=w=56
)到最后的较小特征图。在你提供的图片中,h=w=56
可能指的是在 Swin Transformer 的某个阶段,特征图被处理时的空间分辨率。例如,在较早的阶段,特征图的空间分辨率较高,经过几次降采样后,可能从 224x224 降到 56x56 甚至更低。
因此,特征图的大小 (14x14 或 56x56) 取决于模型的阶段以及具体的网络结构。在 Swin Transformer 中,早期层的特征图可能较大,而后期层的特征图可能较小,这与 Vision Transformer 中固定的 14x14
特征图有所不同。
补充3:
关于正文中h=w=56, m=7 的补充:
在 Swin Transformer 中,h=w=56
和 m=7
是针对特定阶段的特征图大小和窗口大小。这些参数在 Swin Transformer 中是有具体含义的:
h=w=56
的解释4x4
的 patch(相当于 Stride=4
),并将输入图像从原始的 224x224 分辨率降采样到 56x56 的特征图。Stride=4
的操作后,特征图的大小变成 224/4 = 56,既 h=w=56
。m=7
的解释m=7
表示窗口的大小为 7x7
,也就是说,在每一个 7x7
的局部区域内计算自注意力,而不是在整个 56x56
的全局上计算。56x56
)划分为多个 7x7
的窗口,Swin Transformer 可以在保持计算量可控的前提下,捕捉局部的相关性。Swin Transformer 的网络结构通常分为多个阶段,每个阶段的特征图大小和窗口大小可能有所不同:
Stride=4
的 patch embedding 操作,特征图的大小变为 56x56。7x7
的窗口来进行局部自注意力计算。注1:
注2:
class PatchMerging(nn.Module):
r""" Patch Merging Layer.
Args:
dim (int): Number of input channels.
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
def __init__(self, dim, norm_layer=nn.LayerNorm):
super().__init__()
self.dim = dim
self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
self.norm = norm_layer(4 * dim)
def forward(self, x, H, W):
"""
x: B, H*W, C
"""
B, L, C = x.shape
assert L == H * W, "input feature has wrong size"
x = x.view(B, H, W, C)
# padding
# 如果输入feature map的H,W不是2的整数倍,需要进行padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
# to pad the last 3 dimensions, starting from the last dimension and moving forward.
# (C_front, C_back, W_left, W_right, H_top, H_bottom)
# 注意这里的Tensor通道是[B, H, W, C],所以会和官方文档有些不同
x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x[:, 0::2, 0::2, :] # [B, H/2, W/2, C]
x1 = x[:, 1::2, 0::2, :] # [B, H/2, W/2, C]
x2 = x[:, 0::2, 1::2, :] # [B, H/2, W/2, C]
x3 = x[:, 1::2, 1::2, :] # [B, H/2, W/2, C]
x = torch.cat([x0, x1, x2, x3], -1) # [B, H/2, W/2, 4*C]
x = x.view(B, -1, 4 * C) # [B, H/2*W/2, 4*C]
x = self.norm(x)
x = self.reduction(x) # [B, H/2*W/2, 2*C]
return x
正文:
因篇幅问题不能全部显示,请点此查看更多更全内容