博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
tensorflow2.2中定义的ResNet和ResneXt中的bottleneck结构
阅读量:4153 次
发布时间:2019-05-25

本文共 7362 字,大约阅读时间需要 24 分钟。

目录

1、论文中提出的ResNet网络结构

tensorflow的Keras高级API中定义了50,101和152层的ResNet和ResNeXt,其中的bottlenect结构的实现在后面介绍。

ResNet论文中提出的50,101和152层结构如下图所示:

在这里插入图片描述
可以发现,ResNet网络结构中的五个stage分别将feature map尺寸减小一半,输入图片尺寸为224,conv5输出的feature map大小为224/2^5=7;50,101和152层的ResNet中相同stage对应的结构单元是相同的,不同的只是堆叠的次数;不同stage中的结构单元分为三个卷积层,其中前两个卷积的卷积核数相同,最后一个卷积层的卷积核数分别是前两个的4倍。
这种结构单元在原文中被称为bottleneck,bottleneck中包含的三个卷积分别为1x1,3x3和1x1,采用1x1卷积的好处有:
(1)与baseblock相比(如下图),bottleneck结构引入了两次ReLU,大大提高了网络结构的非线性;
(2)采用1*1卷积先降维再升维,大大减少了一个结构单元的计算量,提高了网络结构单位计算量的表现力,同时这种升维降维的过程实现了特征的跨通道线性组合,进一步提高了神经网络的整体表现力。
在这里插入图片描述

2、tensorflow中的三种ResNet或ResNeXt结构单元

tensorflow2.2官方代码中ResNet结构单元共有三种,其中第二种是每个stage前先进行BN和activation操作;第三种是ResNeXt网络所采用的的grouped conv。

2.1、第一种结构单元

ResNet中的第一种bottleneck结构单元,结构单元中含有三个卷积结构,每个卷积结构都包括CBR(conv+BN+ReLu),三个卷积结构中卷积操作的卷积核大小分别为11、33和1*1,官方实现代码:

def block1(x, filters, kernel_size=3, stride=1, conv_shortcut=True, name=None):  """A residual block.  Arguments:    x: input tensor.  x为输入张量,可以接上其它残差结构单元的输出。    filters: integer, filters of the bottleneck layer. #卷积核个数    kernel_size: default 3, kernel size of the bottleneck layer. #卷积核大小    stride: default 1, stride of the first layer. #卷积步长    conv_shortcut: default True, use convolution shortcut if True,        otherwise identity shortcut. #布尔型参数,利用该参数判断shortcut是否要进行conv下采样    name: string, block label. #结构单元名字  Returns:    Output tensor for the residual block. #该结构单元的输出  """  bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1 #根据数据格式是通道在前还是通道在后给bn_axis赋值  if conv_shortcut:  #先计算shortcut,如果需要卷积下采样,则执行该分支来计算shortcut,shortcut的卷积层标记为该结构单元的第零层卷积:0_conv    shortcut = layers.Conv2D(        4 * filters, 1, strides=stride, name=name + '_0_conv')(x)    shortcut = layers.BatchNormalization(        axis=bn_axis, epsilon=1.001e-5, name=name + '_0_bn')(shortcut)  else:    shortcut = x #如果不需要卷积下采样,shortcut直接等于输入#第一个卷积结构:  x = layers.Conv2D(filters, 1, strides=stride, name=name + '_1_conv')(x)  x = layers.BatchNormalization(      axis=bn_axis, epsilon=1.001e-5, name=name + '_1_bn')(x)  x = layers.Activation('relu', name=name + '_1_relu')(x)#第二个卷积结构  x = layers.Conv2D(      filters, kernel_size, padding='SAME', name=name + '_2_conv')(x)  x = layers.BatchNormalization(      axis=bn_axis, epsilon=1.001e-5, name=name + '_2_bn')(x)  x = layers.Activation('relu', name=name + '_2_relu')(x)#第三个卷积结构  x = layers.Conv2D(4 * filters, 1, name=name + '_3_conv')(x) #卷积核数为4 * filters,与shortcut的  x = layers.BatchNormalization(      axis=bn_axis, epsilon=1.001e-5, name=name + '_3_bn')(x)  x = layers.Add(name=name + '_add')([shortcut, x])  x = layers.Activation('relu', name=name + '_out')(x)#add之后再进行激活  return x

以上代码定义的残差单元如下图所示:

在这里插入图片描述
用以下代码定义上图所示网络结构的多个叠加,对于每个stage,含有如上图所示的多个结构单元,并且只有第一个结构单元的shortcut是采用conv下采样:

def stack1(x, filters, blocks, stride1=2, name=None):  """A set of stacked residual blocks.  Arguments:    x: input tensor.    filters: integer, filters of the bottleneck layer in a block.    blocks: integer, blocks in the stacked blocks.    stride1: default 2, stride of the first layer in the first block.    name: string, stack label.  Returns:    Output tensor for the stacked blocks.  """  x = block1(x, filters, stride=stride1, name=name + '_block1')  for i in range(2, blocks + 1):#当blocks为大于等于2的整数时,才进入for循环    x = block1(x, filters, conv_shortcut=False, name=name + '_block' + str(i))  return x

当blocks等于2时,得到的残差网络结构:

在这里插入图片描述

2.2 第二种结构单元

第二种ResNet结构单元定义代码如下所示:

def block2(x, filters, kernel_size=3, stride=1, conv_shortcut=False, name=None):  """A residual block. 第二种结构单元的输入参数与第一种完全相同  Arguments:      x: input tensor.      filters: integer, filters of the bottleneck layer.      kernel_size: default 3, kernel size of the bottleneck layer.      stride: default 1, stride of the first layer.      conv_shortcut: default False, use convolution shortcut if True,        otherwise identity shortcut.      name: string, block label.  Returns:    Output tensor for the residual block.  """  bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1  preact = layers.BatchNormalization(      axis=bn_axis, epsilon=1.001e-5, name=name + '_preact_bn')(x)  preact = layers.Activation('relu', name=name + '_preact_relu')(preact)  if conv_shortcut:    shortcut = layers.Conv2D(        4 * filters, 1, strides=stride, name=name + '_0_conv')(preact)  else:    shortcut = layers.MaxPooling2D(1, strides=stride)(x) if stride > 1 else x#注意这里与第一种结构单元不同#第一个卷积结构  x = layers.Conv2D(      filters, 1, strides=1, use_bias=False, name=name + '_1_conv')(preact)  x = layers.BatchNormalization(      axis=bn_axis, epsilon=1.001e-5, name=name + '_1_bn')(x)  x = layers.Activation('relu', name=name + '_1_relu')(x)#第二个卷积结构  x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name=name + '_2_pad')(x)  x = layers.Conv2D(      filters,      kernel_size,      strides=stride,      use_bias=False,      name=name + '_2_conv')(x)  x = layers.BatchNormalization(      axis=bn_axis, epsilon=1.001e-5, name=name + '_2_bn')(x)  x = layers.Activation('relu', name=name + '_2_relu')(x)#第三个卷积结构  x = layers.Conv2D(4 * filters, 1, name=name + '_3_conv')(x)  x = layers.Add(name=name + '_out')([shortcut, x])  return x

网络结构如下图所示:

在这里插入图片描述

2.3 第三种结构单元

第三种结构单元的卷积操作采用组卷积,组卷积下图所示:

在这里插入图片描述

定义代码如下:

def block3(x,           filters,           kernel_size=3,           stride=1,           groups=32,           conv_shortcut=True,           name=None):  """A residual block.  Arguments:    x: input tensor.    filters: integer, filters of the bottleneck layer.    kernel_size: default 3, kernel size of the bottleneck layer.    stride: default 1, stride of the first layer.    groups: default 32, group size for grouped convolution. #该参数定义的卷积分组数量    conv_shortcut: default True, use convolution shortcut if True,        otherwise identity shortcut.    name: string, block label.  Returns:    Output tensor for the residual block.  """  bn_axis = 3 if backend.image_data_format() == 'channels_last' else 1  if conv_shortcut:    shortcut = layers.Conv2D(        (64 // groups) * filters,        1,        strides=stride,        use_bias=False,        name=name + '_0_conv')(x)    shortcut = layers.BatchNormalization(        axis=bn_axis, epsilon=1.001e-5, name=name + '_0_bn')(shortcut)  else:    shortcut = x  x = layers.Conv2D(filters, 1, use_bias=False, name=name + '_1_conv')(x)  x = layers.BatchNormalization(      axis=bn_axis, epsilon=1.001e-5, name=name + '_1_bn')(x)  x = layers.Activation('relu', name=name + '_1_relu')(x)  c = filters // groups  x = layers.ZeroPadding2D(padding=((1, 1), (1, 1)), name=name + '_2_pad')(x)  x = layers.DepthwiseConv2D(      kernel_size,      strides=stride,      depth_multiplier=c,      use_bias=False,      name=name + '_2_conv')(x)  x_shape = backend.int_shape(x)[1:-1]  x = layers.Reshape(x_shape + (groups, c, c))(x)  x = layers.Lambda(      lambda x: sum(x[:, :, :, :, i] for i in range(c)),      name=name + '_2_reduce')(x)  x = layers.Reshape(x_shape + (filters,))(x)  x = layers.BatchNormalization(      axis=bn_axis, epsilon=1.001e-5, name=name + '_2_bn')(x)  x = layers.Activation('relu', name=name + '_2_relu')(x)  x = layers.Conv2D(      (64 // groups) * filters, 1, use_bias=False, name=name + '_3_conv')(x)  x = layers.BatchNormalization(      axis=bn_axis, epsilon=1.001e-5, name=name + '_3_bn')(x)  x = layers.Add(name=name + '_add')([shortcut, x])  x = layers.Activation('relu', name=name + '_out')(x)  return x

以上代码定义的网络结构如下图所示:

在这里插入图片描述

转载地址:http://iarti.baihongyu.com/

你可能感兴趣的文章
java杂记
查看>>
RunTime.getRuntime().exec()
查看>>
Oracle 分组排序函数
查看>>
VMware Workstation 14中文破解版下载(附密钥)(笔记)
查看>>
日志框架学习
查看>>
日志框架学习2
查看>>
SVN-无法查看log,提示Want to go offline,时间显示1970问题,error主要是 url中 有一层的中文进行了2次encode
查看>>
NGINX
查看>>
Qt文件夹选择对话框
查看>>
DeepLearning tutorial(7)深度学习框架Keras的使用-进阶
查看>>
第三方SDK:JPush SDK Eclipse
查看>>
第三方开源库:imageLoader的使用
查看>>
Android studio_迁移Eclipse项目到Android studio
查看>>
转载知乎-前端汇总资源
查看>>
JavaScript substr() 方法
查看>>
JavaScript slice() 方法
查看>>
JavaScript substring() 方法
查看>>
HTML 5 新的表单元素 datalist keygen output
查看>>
(转载)正确理解cookie和session机制原理
查看>>
jQuery ajax - ajax() 方法
查看>>