A2S2K-ResNet: Paper Overview and Code Analysis

Posted by Linexus on Sunday, January 22, 2023

Paper Analysis of A2S2K-ResNet

Background

To let neurons effectively adapt their receptive-field sizes and cross-channel dependencies, the paper proposes a residual network improved with attention-based adaptive spectral–spatial kernels. The key improvements are:

  • attention-based adaptive spectral–spatial kernel
  • improved spectral–spatial ResNet

Newly designed model components:

  • EFR module
  • A2S2K-Net’s block

Overall Network Architecture

(Figure: overall network architecture diagram)

The network consists of the A2S2K block, ResNet blocks, a pooling layer, and the output layer. The first two modules are described in detail below.

A2S2K-block

(Figure: structure of the A2S2K block)

This block comes from Selective Kernel Networks (SKNet): an attention operation adaptively selects among the features extracted by convolution kernels of different sizes. Its structure involves three parts:

  • Split: generate multiple paths with different kernel sizes, corresponding to different receptive-field sizes of the neurons.
  • Fuse: combine and aggregate the information from all paths to obtain a global, comprehensive representation for computing the selection weights.
  • Select: aggregate the feature maps of the differently sized kernels according to the selection weights (i.e., a weighted sum).

The corresponding code snippet:

# upper branch: 1x1 spatial convolution
x_1x1 = self.conv1x1(X)
x_1x1 = self.batch_norm1x1(x_1x1).unsqueeze(dim=1)
# lower branch: 3x3 spatial convolution
x_3x3 = self.conv3x3(X)
x_3x3 = self.batch_norm3x3(x_3x3).unsqueeze(dim=1)
# concatenate the two branches along a new branch dimension
x1 = torch.cat([x_3x3, x_1x1], dim=1)
# element-wise sum: initial feature fusion
U = torch.sum(x1, dim=1)
# global pooling: squeeze spatial-spectral information for the attention step
S = self.pool(U)
# 3D convolution reduces channels (squeeze)
Z = self.conv_se(S)
attention_vector = torch.cat(
    [
        # 3D convolution expands channels back (excitation)
        self.conv_ex(Z).unsqueeze(dim=1),
        self.conv_ex(Z).unsqueeze(dim=1)
    ],
    dim=1)
# softmax normalizes the weights so the two branches sum to one
attention_vector = self.softmax(attention_vector)
# multiply by the attention weights and sum: adaptive kernel selection
V = (x1 * attention_vector).sum(dim=1)
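The split–fuse–select mechanism above can be sketched as a self-contained module. The sketch below uses 2-D convolutions and illustrative names (`SelectiveKernelSketch`, `branch3`, `branch5`) for brevity; the paper's implementation uses 3-D convolutions.

```python
import torch
import torch.nn as nn

class SelectiveKernelSketch(nn.Module):
    def __init__(self, channels, reduction=4):
        super().__init__()
        # Split: two paths with different receptive fields
        self.branch3 = nn.Conv2d(channels, channels, 3, padding=1)
        self.branch5 = nn.Conv2d(channels, channels, 5, padding=2)
        # Fuse: sum the paths, then squeeze via global pooling + channel reduction
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.squeeze = nn.Sequential(
            nn.Conv2d(channels, channels // reduction, 1),
            nn.ReLU(inplace=True))
        # Select: one excitation per path, softmax across paths
        self.excite3 = nn.Conv2d(channels // reduction, channels, 1)
        self.excite5 = nn.Conv2d(channels // reduction, channels, 1)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        feats = torch.stack([self.branch3(x), self.branch5(x)], dim=1)  # (B,2,C,H,W)
        u = feats.sum(dim=1)                    # initial fusion
        z = self.squeeze(self.pool(u))          # compact global descriptor
        attn = torch.stack([self.excite3(z), self.excite5(z)], dim=1)
        attn = self.softmax(attn)               # weights over the two paths sum to 1
        return (feats * attn).sum(dim=1)        # weighted selection

x = torch.randn(2, 16, 9, 9)
y = SelectiveKernelSketch(16)(x)
print(y.shape)  # torch.Size([2, 16, 9, 9])
```

Note that in the snippet above the same `conv_ex` is applied to both branches, so after the softmax the two branch weights are necessarily equal; SKNet proper learns a separate projection per path, as this sketch does.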

EFR Module

(Figure: EFR module)

This is essentially feature re-extraction. The module derives from the SE (squeeze-and-excitation) block in SENet: it further re-extracts the relatively important features after convolution, which can effectively reduce the number of network layers and improve speed. (SENet is also the main reference behind SKNet.)
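For reference, an SE-style squeeze-and-excitation gate can be sketched as below; the 2-D layout, the `reduction=4` ratio, and the names are illustrative, not the paper's exact module.

```python
import torch
import torch.nn as nn

class SEBlock(nn.Module):
    def __init__(self, channels, reduction=4):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)           # squeeze: global context
        self.fc = nn.Sequential(
            nn.Linear(channels, channels // reduction),
            nn.ReLU(inplace=True),
            nn.Linear(channels // reduction, channels),
            nn.Sigmoid())                             # excitation: per-channel gate in (0, 1)

    def forward(self, x):
        b, c, _, _ = x.shape
        w = self.fc(self.pool(x).view(b, c)).view(b, c, 1, 1)
        return x * w                                  # re-weight channels

x = torch.randn(2, 16, 9, 9)
y = SEBlock(16)(x)
print(y.shape)  # torch.Size([2, 16, 9, 9])
```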

Improved Residual Network

(Figure: improved residual block)

The original paper does not explain why the block is modified this way; my guess is that it is so each 3D convolution is followed by a BN layer and an activation, maintaining this cross-layer relationship between layers.
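One plausible reading of the modified residual unit is sketched below, assuming the `start_block`/`end_block` flags in the code control where BN and ReLU are placed around the shortcut. This is an illustrative reconstruction, not the repository's exact `Residual` class.

```python
import torch
import torch.nn as nn

class ResidualSketch(nn.Module):
    def __init__(self, channels, kernel_size, padding,
                 start_block=False, end_block=False):
        super().__init__()
        self.start_block, self.end_block = start_block, end_block
        self.bn0 = nn.BatchNorm3d(channels)
        self.conv1 = nn.Conv3d(channels, channels, kernel_size, padding=padding)
        self.bn1 = nn.BatchNorm3d(channels)
        self.conv2 = nn.Conv3d(channels, channels, kernel_size, padding=padding)
        self.bn2 = nn.BatchNorm3d(channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        # start blocks skip the leading BN+ReLU (the input is already activated)
        out = x if self.start_block else self.relu(self.bn0(x))
        out = self.relu(self.bn1(self.conv1(out)))
        out = self.conv2(out)
        out = out + x                      # identity shortcut
        if self.end_block:                 # end blocks re-activate the sum
            out = self.relu(self.bn2(out))
        return out

x = torch.randn(2, 8, 7, 7, 16)
y = ResidualSketch(8, (1, 1, 7), (0, 0, 3))(x)
print(y.shape)  # torch.Size([2, 8, 7, 7, 16])
```

With kernel (1, 1, 7) and padding (0, 0, 3) the spectral length is preserved, so the shortcut addition is shape-compatible without a projection.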

Code

Network code

import math
import torch
import torch.nn as nn

PARAM_KERNEL_SIZE = 24  # number of feature maps per branch


class S3KAIResNet(nn.Module):
    def __init__(self, band, classes, reduction):
        super(S3KAIResNet, self).__init__()
        self.name = 'SSRN'
        self.conv1x1 = nn.Conv3d(
            in_channels=1,
            out_channels=PARAM_KERNEL_SIZE,
            kernel_size=(1, 1, 7),
            stride=(1, 1, 2),
            padding=0)
        self.conv3x3 = nn.Conv3d(
            in_channels=1,
            out_channels=PARAM_KERNEL_SIZE,
            kernel_size=(3, 3, 7),
            stride=(1, 1, 2),
            padding=(1, 1, 0))

        self.batch_norm1x1 = nn.Sequential(
            nn.BatchNorm3d(
                PARAM_KERNEL_SIZE, eps=0.001, momentum=0.1,
                affine=True),  # 0.1
            nn.ReLU(inplace=True))
        self.batch_norm3x3 = nn.Sequential(
            nn.BatchNorm3d(
                PARAM_KERNEL_SIZE, eps=0.001, momentum=0.1,
                affine=True),  # 0.1
            nn.ReLU(inplace=True))

        self.pool = nn.AdaptiveAvgPool3d(1)
        self.conv_se = nn.Sequential(
            nn.Conv3d(
                PARAM_KERNEL_SIZE, band // reduction, 1, padding=0, bias=True),
            nn.ReLU(inplace=True))
        self.conv_ex = nn.Conv3d(
            band // reduction, PARAM_KERNEL_SIZE, 1, padding=0, bias=True)
        self.softmax = nn.Softmax(dim=1)

        self.res_net1 = Residual(
            PARAM_KERNEL_SIZE,
            PARAM_KERNEL_SIZE, (1, 1, 7), (0, 0, 3),
            start_block=True)
        self.res_net2 = Residual(PARAM_KERNEL_SIZE, PARAM_KERNEL_SIZE,
                                 (1, 1, 7), (0, 0, 3))
        self.res_net3 = Residual(PARAM_KERNEL_SIZE, PARAM_KERNEL_SIZE,
                                 (3, 3, 1), (1, 1, 0))
        self.res_net4 = Residual(
            PARAM_KERNEL_SIZE,
            PARAM_KERNEL_SIZE, (3, 3, 1), (1, 1, 0),
            end_block=True)

        kernel_3d = math.ceil((band - 6) / 2)
        # print(kernel_3d)

        self.conv2 = nn.Conv3d(
            in_channels=PARAM_KERNEL_SIZE,
            out_channels=128,
            padding=(0, 0, 0),
            kernel_size=(1, 1, kernel_3d),
            stride=(1, 1, 1))
        self.batch_norm2 = nn.Sequential(
            nn.BatchNorm3d(128, eps=0.001, momentum=0.1, affine=True),  # 0.1
            nn.ReLU(inplace=True))
        self.conv3 = nn.Conv3d(
            in_channels=1,
            out_channels=PARAM_KERNEL_SIZE,
            padding=(0, 0, 0),
            kernel_size=(3, 3, 128),
            stride=(1, 1, 1))
        self.batch_norm3 = nn.Sequential(
            nn.BatchNorm3d(
                PARAM_KERNEL_SIZE, eps=0.001, momentum=0.1,
                affine=True),  # 0.1
            nn.ReLU(inplace=True))

        self.avg_pooling = nn.AvgPool3d(kernel_size=(5, 5, 1))
        self.full_connection = nn.Sequential(
            nn.Linear(PARAM_KERNEL_SIZE, classes)
            # nn.Softmax()
        )

    def forward(self, X):
        # A2S2K-block
        x_1x1 = self.conv1x1(X)
        x_1x1 = self.batch_norm1x1(x_1x1).unsqueeze(dim=1)
        x_3x3 = self.conv3x3(X)
        x_3x3 = self.batch_norm3x3(x_3x3).unsqueeze(dim=1)

        x1 = torch.cat([x_3x3, x_1x1], dim=1)
        U = torch.sum(x1, dim=1)
        S = self.pool(U)
        Z = self.conv_se(S)
        attention_vector = torch.cat(
            [
                self.conv_ex(Z).unsqueeze(dim=1),
                self.conv_ex(Z).unsqueeze(dim=1)
            ],
            dim=1)
        attention_vector = self.softmax(attention_vector)
        V = (x1 * attention_vector).sum(dim=1)
		
        # res-net block
        # start block
        x2 = self.res_net1(V)
        # middle block
        x2 = self.res_net2(x2)
        # helps prevent overfitting
        x2 = self.batch_norm2(self.conv2(x2))
        x2 = x2.permute(0, 4, 2, 3, 1)
        x2 = self.batch_norm3(self.conv3(x2))
        # middle block
        x3 = self.res_net3(x2)
        # last block
        x3 = self.res_net4(x3)
        x4 = self.avg_pooling(x3)
        x4 = x4.view(x4.size(0), -1)
        return self.full_connection(x4)


model = S3KAIResNet(BAND, CLASSES_NUM, 2).cuda()

summary(model, input_data=(1, img_rows, img_cols, BAND), verbose=1)
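As a sanity check on the spectral-axis bookkeeping (the band counts below are illustrative): `conv1x1` uses kernel `(1, 1, 7)` with stride `(1, 1, 2)` along the spectral axis, and `kernel_3d = math.ceil((band - 6) / 2)` is exactly the resulting spectral length, so `conv2` collapses that axis to 1.

```python
import math

def spectral_out(band, kernel=7, stride=2):
    # output length of a valid (no-padding) convolution along the spectral axis
    return (band - kernel) // stride + 1

for band in (103, 176, 200):  # illustrative band counts
    assert spectral_out(band) == math.ceil((band - 6) / 2)

print(spectral_out(200))  # 97
```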

The remaining code is too long to include here; see the full source for details.