ResNet-RS架构复现--CVPR2021

举报
别团等shy哥发育 发表于 2023/01/08 18:11:29 2023/01/08
【摘要】 @toc参考论文:Revisiting ResNets: Improved Training and Scaling Strategies 作者:Irwan Bello, William Fedus, Xianzhi Du, Ekin D. Cubuk, Aravind Srinivas, Tsung-Yi Lin, Jonathon Shlens, Barret Zoph这里主要是架构复现...

@toc

参考论文:Revisiting ResNets: Improved Training and Scaling Strategies

作者:Irwan Bello, William Fedus, Xianzhi Du, Ekin D. Cubuk, Aravind Srinivas, Tsung-Yi Lin, Jonathon Shlens, Barret Zoph

这里主要是架构复现,由于论文中细节太多,原理部分只是不细讲。

1、摘要

  我们的工作重新审视了规范的 ResNet (He et al., 2015),并研究了这三个方面,以试图解开它们。也许令人惊讶的是,我们发现训练和扩展策略可能比架构变化更重要,而且由此产生的 ResNet 与最近最先进的模型相匹配。我们展示了表现最佳的缩放策略取决于训练方案,并提供了两种新的缩放策略:(1)在可能发生过度拟合的情况下缩放模型深度(否则宽度缩放更可取); (2) 提高图像分辨率的速度比之前推荐的要慢(Tan & Le,2019)。使用改进的训练和扩展策略,我们设计了一系列 ResNet 架构 ResNet-RS,它比 TPU 上的 EfficientNets 快 1.7 倍 - 2.7 倍,同时在 ImageNet 上实现了类似的精度。在大规模的半监督学习设置中,ResNet-RS 实现了 86.2% 的 top-1 ImageNet 准确率,同时比 EfficientNetNoisyStudent 快 4.7 倍

2、ResNet-D架构

  这篇论文是在ResNet-D架构基础上改造的。

  在Bag of Tricks for Image Classification with Convolutional Neural Networks 这篇论文中的ResNet-D基础结构如下所示:

image-20220907220047082

  注意,残差边上多了个池化操作。

  ResNet-D (He et al., 2018) 结合了对原始 ResNet 架构的以下四个调整。

  • 首先,在 InceptionV3 (Szegedy et al., 2016) 中首次提出,将干中的 7×7 卷积替换为三个较小的 3×3 卷积。
  • 其次,在下采样块的残差路径中切换前两个卷积的步幅大小。
  • 第三,将下采样块的skip connection路径中的 stride-2 1×1 卷积替换为 stride-2 2×2 平均池化,然后是non-strided 1×1 卷积。
  • 第四,去除 stride2 3×3 max pool layer,下采样发生在下一个bottleneck block的第一个 3×3 卷积中。

image-20220907215752784

3、改进训练方法

  我们在表 1 中展示了关于训练、正则化方法和架构变化的附加研究。基准 ResNet-200 获得了 79.0% 的 top-1 准确率。我们仅通过改进的训练方法将其性能提高到 82.2% (+3.2%),而无需任何架构更改。当添加两个常见且简单的架构更改(Squeeze-and-Excitation 和 ResNet-D)时,我们将性能进一步提高到 83.4%。仅训练方法就导致了总改进的 3/4,这证明了它们对 ImageNet 性能的关键影响。

image-20220907220422333

  表 1. ResNet-RS 训练配方的附加研究。颜色指的是训练方法、正则化方法和架构改进。使用逐步学习率衰减计划对基线 ResNet-200 进行了标准 90 个 epoch 的训练。图像分辨率为 256×256。所有数字都在 ImageNet 验证集上报告,并在 2 次运行中取平均值。 † 仅在使用正则化方法后,将训练持续时间增加到 350 个 epoch 才会有用,否则会由于过度拟合而导致精度下降。

4、改进的缩放策略

  • 在可能发生过拟合的情况下进行深度缩放:对于较长的epoch,深度缩放优于宽度缩放;对于较短的epoch,宽度缩放优于深度缩放。
  • 缓慢的图像分辨率缩放。

image-20220907220756954

   图 3. ResNet 在深度、宽度、图像分辨率和训练时期的缩放。在训练模型 10、100 或 350 个 epoch 时,我们比较了四种不同图像分辨率 [128,160,224,320] 的深度缩放和宽度缩放。我们发现表现最好的缩放策略取决于训练机制,这揭示了从小规模机制推断缩放规则的缺陷。

  (左)10 Epoch Regime:宽度缩放是速度精度帕累托曲线的最佳策略。

  (中)100 Epoch Regime:深度缩放有时优于宽度缩放。

  (右)350 Epoch Regime:深度缩放始终比宽度缩放有很大的优势。即使使用正则化方法,过度拟合仍然是一个问题。模型详细信息:所有模型都从深度 101 开始,并增加到 [101,200,300,400]。所有模型宽度都以 1.0x 的乘数开始,并通过 [1.0,1.5,2.0] 增加。对于所有模型,我们调整正则化以限制过度拟合(参见附录 E)。在 ImageNet minival-set 上报告准确度,在 TPU 上测量训练时间。

5、Appendix

5.1 Pareto 曲线中所有 ResNet-RS 模型的详细信息

  本节详细介绍 ResNet-RS Pareto 曲线中的所有模型。在表 7 中,我们观察到我们的 ResNet-RS 模型在 TPU 上的 EfficientNet Pareto 曲线上获得了 1.7x - 2.7x 的加速。

image-20220907221005384

  表 7. Pareto 曲线中 ResNet-RS 模型的详细信息。使用第 5 节中提到的改进对所有模型进行了 350 个 epoch 的训练。所有 ResNet-RS 模型的确切超参数在表 8 中。Tesla V100 GPU 上的延迟以全精度 (float32) 测量。 TPUv3 上的延迟是使用 bfloat16 精度测量的。所有延迟都是用 128 个图像的初始训练批量大小测量的,该大小除以 2 直到它适合加速器。

5.2 ResNet-RS 架构细节

  我们提供了有关 ResNet-RS 架构更改的更多详细信息。我们重申 ResNet-RS 是:改进的缩放策略、改进的训练方法、ResNet-D 修改(He 等人,2018 年)和 SqueezeExcitation 模块(Hu 等人,2018 年)的组合

  表 11 显示了我们工作中使用的所有 ResNet 深度的块布局。 ResNet-50 到 ResNet-200 使用 He 等人的标准块配置。 (2015 年)。 ResNet-270 及更高版本主要扩展 c3 和 c4 中的块数,我们尝试保持它们的比例大致恒定。我们凭经验发现,在较低阶段添加块会限制过度拟合,因为较低层中的块具有显着较少的参数,即使所有块具有相同数量的 FLOP。图 6 显示了我们的 ResNet-RS 模型中使用的 ResNet-D 架构更改。

image-20220907221339109

   图 6. ResNet-RS 架构图。

  输出大小假定输入图像分辨率为 224×224。

  在卷积布局中,x2 是指第一个 3×3 卷积,步长为 2。

  ResNet-RS 架构是 Squeeze-and-Excitation 和 ResNet-D 的简单组合。

  × 符号表示块在 ResNet-101 架构中重复的次数。这些值根据表 11 中的块布局随深度变化。

5.3 Scaling Analysis Regularization and Model Details

image-20220907221444890

  表 12. Dropout values for filter scaling.。 filter scaling是指基于原始 ResNet 架构中filters数量的filters缩放乘数。

6、ResNet-RS架构搭建

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.models import Model
from typing import Callable, Dict, List, Union

6.1 模型配置项

DEPTH_TO_WEIGHT_VARIANTS = {
    50: [160],
    101: [160, 192],
    152: [192, 224, 256],
    200: [256],
    270: [256],
    350: [256, 320],
    420: [320],
}
BLOCK_ARGS = {
    50: [
        {
            "input_filters": 64,
            "num_repeats": 3
        },
        {
            "input_filters": 128,
            "num_repeats": 4
        },
        {
            "input_filters": 256,
            "num_repeats": 6
        },
        {
            "input_filters": 512,
            "num_repeats": 3
        },
    ],
    101: [
        {
            "input_filters": 64,
            "num_repeats": 3
        },
        {
            "input_filters": 128,
            "num_repeats": 4
        },
        {
            "input_filters": 256,
            "num_repeats": 23
        },
        {
            "input_filters": 512,
            "num_repeats": 3
        },
    ],
    152: [
        {
            "input_filters": 64,
            "num_repeats": 3
        },
        {
            "input_filters": 128,
            "num_repeats": 8
        },
        {
            "input_filters": 256,
            "num_repeats": 36
        },
        {
            "input_filters": 512,
            "num_repeats": 3
        },
    ],
    200: [
        {
            "input_filters": 64,
            "num_repeats": 3
        },
        {
            "input_filters": 128,
            "num_repeats": 24
        },
        {
            "input_filters": 256,
            "num_repeats": 36
        },
        {
            "input_filters": 512,
            "num_repeats": 3
        },
    ],
    270: [
        {
            "input_filters": 64,
            "num_repeats": 4
        },
        {
            "input_filters": 128,
            "num_repeats": 29
        },
        {
            "input_filters": 256,
            "num_repeats": 53
        },
        {
            "input_filters": 512,
            "num_repeats": 4
        },
    ],
    350: [
        {
            "input_filters": 64,
            "num_repeats": 4
        },
        {
            "input_filters": 128,
            "num_repeats": 36
        },
        {
            "input_filters": 256,
            "num_repeats": 72
        },
        {
            "input_filters": 512,
            "num_repeats": 4
        },
    ],
    420: [
        {
            "input_filters": 64,
            "num_repeats": 4
        },
        {
            "input_filters": 128,
            "num_repeats": 44
        },
        {
            "input_filters": 256,
            "num_repeats": 87
        },
        {
            "input_filters": 512,
            "num_repeats": 4
        },
    ],
}
CONV_KERNEL_INITIALIZER = {
    "class_name": "VarianceScaling",
    "config": {
        "scale": 2.0,
        "mode": "fan_out",
        "distribution": "truncated_normal"
    },
}

这里只搭建ResNet-RS101架构

6.2 get_survival_probability

根据区块数和初始速率获取生存概率
def get_survival_probability(init_rate, block_num, total_blocks):
    return init_rate * float(block_num) / total_blocks

6.3 fixed_padding

def fixed_padding(inputs, kernel_size):
    """沿空间维度填充输入,与输入大小无关"""
    pad_total = kernel_size - 1
    pad_beg = pad_total // 2
    pad_end = pad_total - pad_beg

    # 使用 ZeroPadding 来避免 TFOpLambda 层
    padded_inputs = layers.ZeroPadding2D(
        padding=((pad_beg, pad_end), (pad_beg, pad_end)))(inputs)

    return padded_inputs

6.4 Conv2DFixedPadding

# Conv2D block with fixed padding
def Conv2DFixedPadding(filters, kernel_size, strides, name=None):
    def apply(inputs):
        if strides > 1:
            inputs = fixed_padding(inputs, kernel_size)
        return layers.Conv2D(filters=filters,
                             kernel_size=kernel_size,
                             strides=strides,
                             padding='same' if strides == 1 else 'valid',
                             use_bias=False,
                             kernel_initializer=CONV_KERNEL_INITIALIZER,
                             name=name)(inputs)

    return apply

6.5 STEM块

# ResNet-D型STEM块
def STEM(inputs,
         bn_momentum: float = 0.0,
         bn_epsilon: float = 1e-5,
         activation: str = 'relu',
         name=None):
    # first stem block
    x = Conv2DFixedPadding(filters=32,
                           kernel_size=3,
                           strides=2,
                           name=name + '_stem_conv_1')(inputs)
    x = layers.BatchNormalization(momentum=bn_momentum,
                                  epsilon=bn_epsilon,
                                  name=name + '_stem_batch_norm_1')(x)
    x = layers.Activation(activation, name=name + '_stem_act_1')(x)

    # second stem block
    x = Conv2DFixedPadding(filters=32,
                           kernel_size=3,
                           strides=1,
                           name=name + '_stem_conv_2')(x)
    x = layers.BatchNormalization(momentum=bn_momentum,
                                  epsilon=bn_epsilon,
                                  name=name + '_stem_batch_norm_2')(x)
    x = layers.Activation(activation, name=name + '_stem_act_2')(x)

    # final stem block
    x = Conv2DFixedPadding(filters=64,
                           kernel_size=3,
                           strides=1,
                           name=name + '_stem_conv_3')(x)
    x = layers.BatchNormalization(momentum=bn_momentum,
                                  epsilon=bn_epsilon,
                                  name=name + '_stem_batch_norm_3')(x)
    x = layers.Activation(activation, name=name + '_stem_act_3')(x)

    # Replace stem max pool:
    x = Conv2DFixedPadding(filters=64,
                           kernel_size=3,
                           strides=2,
                           name=name + '_stem_conv_4')(x)
    x = layers.BatchNormalization(momentum=bn_momentum,
                                  epsilon=bn_epsilon,
                                  name=name + 'stem_batch_norm_4')(x)
    x = layers.Activation(activation, name=name + '_stem_act_4')(x)

    return x

6.6 SE注意力机制模块

def SE(inputs,
       in_filters: int,
       se_ratio: float = 0.25,
       expand_ratio: int = 1,
       name=None):
    x = layers.GlobalAveragePooling2D(name=name + '_se_squeeze')(inputs)

    se_shape = (1, 1, x.shape[-1])
    x = layers.Reshape(se_shape, name=name + '_se_reshape')(x)

    num_reduced_filters = max(1, int(in_filters * 4 * se_ratio))

    x = layers.Conv2D(filters=num_reduced_filters,
                      kernel_size=(1, 1),
                      strides=[1, 1],
                      kernel_initializer=CONV_KERNEL_INITIALIZER,
                      padding='same',
                      use_bias=False,
                      activation='relu',
                      name=name + '_se_reduce')(x)
    x = layers.Conv2D(filters=4 * in_filters * expand_ratio,  # Expand ratio is 1 by default
                      kernel_size=[1, 1],
                      strides=[1, 1],
                      kernel_initializer=CONV_KERNEL_INITIALIZER,
                      padding='same',
                      use_bias=False,
                      activation='sigmoid',
                      name=name + '_se_expand')(x)
    out = layers.multiply([inputs, x], name=name + '_se_excite')
    return out

6.7 Bottleneck Block

def BottleneckBlock(filters: int,
                    strides: int,
                    use_projection: bool,
                    bn_momentum: float = 0.0,
                    bn_epsilon: float = 1e-5,
                    activation: str = 'relu',
                    se_ratio: float = 0.25,
                    survival_probability: float = 0.8,
                    name=None):
    # 带有BN的残差网络的bottle block变体
    def apply(inputs):

        shortcut = inputs
        # 是否需要projection shortcut
        if use_projection:
            filters_out = filters * 4
            if strides == 2:
                shortcut = layers.AveragePooling2D(pool_size=(2, 2),
                                                   strides=(2, 2),
                                                   padding='same',
                                                   name=name + '_projection_pooling')(inputs)
                shortcut = Conv2DFixedPadding(filters=filters_out,
                                              kernel_size=1,
                                              strides=1,
                                              name=name + '_projection_conv')(shortcut)
            else:
                shortcut = Conv2DFixedPadding(filters=filters_out,
                                              kernel_size=1,
                                              strides=strides,
                                              name=name + '_projection_conv')(inputs)
                shortcut = layers.BatchNormalization(momentum=bn_momentum,
                                                     epsilon=bn_epsilon,
                                                     name=name + '_projection_batch_norm')(shortcut)

        # first conv layer:1x1 conv
        x = Conv2DFixedPadding(filters=filters,
                               kernel_size=1,
                               strides=1,
                               name=name + '_conv_1')(inputs)
        x = layers.BatchNormalization(momentum=bn_momentum,
                                      epsilon=bn_epsilon,
                                      name=name + 'batch_norm_1')(x)
        x = layers.Activation(activation, name=name + '_act_1')(x)

        # second conv layer:3x3 conv
        x = Conv2DFixedPadding(filters=filters,
                               kernel_size=3,
                               strides=strides,
                               name=name + '_conv_2')(x)
        x = layers.BatchNormalization(momentum=bn_momentum,
                                      epsilon=bn_epsilon,
                                      name=name + '_batch_norm_2')(x)
        x = layers.Activation(activation, name=name + '_act_2')(x)

        # third conv layer:1x1 conv
        x = Conv2DFixedPadding(filters=filters * 4,
                               kernel_size=1,
                               strides=1,
                               name=name + '_conv_3')(x)
        x = layers.BatchNormalization(momentum=bn_momentum,
                                      epsilon=bn_epsilon,
                                      name=name + '_batch_norm_3')(x)

        if 0 < se_ratio < 1:
            x = SE(x, filters, se_ratio=se_ratio, name=name + '_se')

        # Drop connect
        if survival_probability:
            x = layers.Dropout(survival_probability,
                               noise_shape=(None, 1, 1, 1),
                               name=name + '_drop')(x)
        x = layers.Add()([x, shortcut])

        return layers.Activation(activation, name=name + '_output_act')(x)

    return apply

6.8 Block Group

def BlockGroup(filters,
               strides,
               num_repeats,  # Block重复次数
               se_ratio: float = 0.25,
               bn_epsilon: float = 1e-5,
               bn_momentum: float = 0.0,
               activation: str = "relu",
               survival_probability: float = 0.8,
               name=None):
    """Create one group of blocks for the ResNet model."""

    def apply(inputs):
        # 只有每个block_group的第一个block块使用projection shortcut和strides
        x = BottleneckBlock(
            filters=filters,
            strides=strides,
            use_projection=True,
            se_ratio=se_ratio,
            bn_epsilon=bn_epsilon,
            bn_momentum=bn_momentum,
            activation=activation,
            survival_probability=survival_probability,
            name=name + "_block_0_",
        )(inputs)

        for i in range(1, num_repeats):
            x = BottleneckBlock(
                filters=filters,
                strides=1,
                use_projection=False,
                se_ratio=se_ratio,
                activation=activation,
                bn_epsilon=bn_epsilon,
                bn_momentum=bn_momentum,
                survival_probability=survival_probability,
                name=name + f"_block_{i}_",
            )(x)
        return x

    return apply

6.9 ResNetRS

# 构建ResNet-RS模型:这里复现ResNet-RS101
def ResNetRS(depth: int,  # ResNet网络的深度,101:[160,192]
             input_shape=None,
             bn_momentum=0.0,  # BN层的动量参数
             bn_epsilon=1e-5,  # BN层的Epsilon参数
             activation: str = 'relu',  # 激活函数
             se_ratio=0.25,  # 挤压和激发曾的比例
             dropout_rate=0.25,  # 最终分类曾之前的dropout
             drop_connect_rate=0.2,  # skip connection的丢失率
             include_top=True,  # 是否在网络顶部包含全连接层
             block_args: List[Dict[str, int]] = None,  # 字典列表,构造块模块的参数
             model_name='resnet-rs',  # 模型的名称
             pooling=None,  # 可选的池化模式
             weights='imagenet',
             input_tensor=None,
             classes=1000,  # 分类数
             classifier_activation: Union[str, Callable] = 'softmax',  # 分类器激活
             include_preprocessing=True):  # 是否包含预处理层(对输入图像通过ImageNet均值和标准差进行归一化):

    img_input = layers.Input(shape=input_shape)
    x = img_input
    inputs = img_input

    # 这里本来有个预处理判断,tensorflow版本太低。
    # if include_preprocessing:
    #     num_channels=input_shape[-1]
    #     if num_channels==3:
    #         # 预处理

    # Build stem
    x = STEM(x, bn_momentum=bn_momentum, bn_epsilon=bn_epsilon, activation=activation, name='stem')

    # Build blocks
    if block_args is None:
        block_args = BLOCK_ARGS[depth]

    for i, args in enumerate(block_args):
        # print(i,args)
        survival_probability = get_survival_probability(init_rate=drop_connect_rate,
                                                        block_num=i + 2,
                                                        total_blocks=len(block_args) + 1)
        # args['input_filters']=[64,128,256,512]
        # 只有第一个BlockGroup的stride=1,后面三个都是stride=2
        x = BlockGroup(filters=args['input_filters'],
                       activation=activation,
                       strides=(1 if i == 0 else 2),
                       num_repeats=args['num_repeats'],
                       se_ratio=se_ratio,
                       bn_momentum=bn_momentum,
                       bn_epsilon=bn_epsilon,
                       survival_probability=survival_probability,
                       name=f"BlockGroup{i + 2}_")(x)
    # Build head:
    if include_top:
        x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        if dropout_rate > 0:
            x = layers.Dropout(dropout_rate, name='top_dropout')(x)
        x = layers.Dense(classes, activation=classifier_activation, name='predictions')(x)
    else:
        if pooling == 'avg':
            x = layers.GlobalAveragePooling2D(name='avg_pool')(x)
        elif pooling == 'max':
            x = layers.GlobalMaxPooling2D(name='max_pool')(x)

    # Create model
    model = Model(inputs, x, name=model_name)

    return model

6.10 ResNetRS101架构

# build ResNet-RS101 model
def ResNetRS101(include_top=True,
                weights='imagenet',
                classes=1000,
                input_shape=None,
                input_tensor=None,
                pooling=None,
                classifier_activation='softmax',
                include_preprocessing=True):
    return ResNetRS(depth=101,
                    include_top=include_top,
                    drop_connect_rate=0.0,
                    dropout_rate=0.25,
                    weights=weights,
                    classes=classes,
                    input_shape=input_shape,
                    input_tensor=input_tensor,
                    pooling=pooling,
                    classifier_activation=classifier_activation,
                    model_name='resnet-rs-101',
                    include_preprocessing=include_preprocessing)

if __name__ == '__main__':
    model = ResNetRS101(input_shape=(224, 224, 3), classes=1000)
    model.summary()

image-20220907222127622

6.11 模型结构大图

References

Revisiting ResNets: Improved Training and Scaling Strategies

https://github.com/keras-team/keras/blob/07e13740fd181fc3ddec7d9a594d8a08666645f6/keras/applications/resnet_rs.py#L525

Bag of Tricks for Image Classification with Convolutional Neural Networks

https://www.tensorflow.org/api_docs/python/tf/keras/applications/resnet_rs

【版权声明】本文为华为云社区用户原创内容,转载时必须标注文章的来源(华为云社区)、文章链接、文章作者等基本信息, 否则作者和本社区有权追究责任。如果您发现本社区中有涉嫌抄袭的内容,欢迎发送邮件进行举报,并提供相关证据,一经查实,本社区将立刻删除涉嫌侵权内容,举报邮箱: cloudbbs@huaweicloud.com
  • 点赞
  • 收藏
  • 关注作者

评论(0

0/1000
抱歉,系统识别当前为高风险访问,暂不支持该操作

全部回复

上滑加载中

设置昵称

在此一键设置昵称,即可参与社区互动!

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。

*长度不超过10个汉字或20个英文字符,设置后3个月内不可修改。