A New Re-parameterization Method: DBB, the Upgraded ACNet
Source: 手寫AI
1. Opening Notes

Knowledge point 1: parallel multi-branch structures extract features with stronger representational power. For details, recall DenseNet, VOVNet, Res2Net and PeleeNet (each will be covered in later posts).

Knowledge point 2: parallel multi-branch structures bring a larger computational overhead. See CSPNet's analysis of this.

Knowledge point 3: ACNet's re-parameterization, which replaces a single 3×3 conv with 1×3 conv + 3×1 conv + 3×3 conv, is effective. See the analysis of ACNet.

Knowledge point 4: is there an even better re-parameterization? Answer: yes. DBB can be regarded as ACNet v2, a full upgrade!

2. A Brief Look at DBB
Diverse Branch Block (DBB) is the next step after ACNet in exploring structural re-parameterization, effectively ACNet v2. DBB designs an Inception-like block that enriches the feature space of a convolutional block with multiple parallel branches, including average pooling and multi-scale convolutions. Before inference, the multi-branch structure is re-parameterized and fused into a single main branch, speeding up inference while also improving accuracy.

As the schematic in the paper shows, DBB resembles Inception: it augments the original K×K conv with 1×1, 1×1-K×K and 1×1-AVG branch combinations. For the 1×1-K×K branch, the number of internal channels is set equal to the number of input channels and the 1×1 conv is initialized as an identity matrix; the other branches use standard initialization. In addition, a BN layer is added after every conv to provide training-time nonlinearity, which is essential for the accuracy gain.
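The reason these parallel branches can later be collapsed into one conv is that convolution is linear: summing the outputs of two parallel convolutions equals a single convolution with the summed kernels. A quick toy check of this additivity (not part of the DBB code, shapes chosen arbitrarily):

import torch
import torch.nn.functional as F

# Two parallel 3x3 convolutions applied to the same input ...
x = torch.randn(1, 8, 32, 32)
k1 = torch.randn(16, 8, 3, 3)
k2 = torch.randn(16, 8, 3, 3)
two_branches = F.conv2d(x, k1, padding=1) + F.conv2d(x, k2, padding=1)

# ... equal one convolution with the summed kernel (up to float error).
merged = F.conv2d(x, k1 + k2, padding=1)
print(torch.allclose(two_branches, merged, atol=1e-5))  # True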
3. Implementation of DBB

Below is the PyTorch implementation of DBB:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from dbb_transforms import *


def conv_bn(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1,
            padding_mode='zeros'):
    # A bias-free conv followed by BN, packed into a Sequential.
    conv_layer = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                           stride=stride, padding=padding, dilation=dilation, groups=groups,
                           bias=False, padding_mode=padding_mode)
    bn_layer = nn.BatchNorm2d(num_features=out_channels, affine=True)
    se = nn.Sequential()
    se.add_module('conv', conv_layer)
    se.add_module('bn', bn_layer)
    return se
class IdentityBasedConv1x1(nn.Conv2d):
    def __init__(self, channels, groups=1):
        super(IdentityBasedConv1x1, self).__init__(in_channels=channels,
                                                   out_channels=channels,
                                                   kernel_size=1,
                                                   stride=1,
                                                   padding=0,
                                                   groups=groups,
                                                   bias=False)
        assert channels % groups == 0
        input_dim = channels // groups
        id_value = np.zeros((channels, input_dim, 1, 1))
        for i in range(channels):
            id_value[i, i % input_dim, 0, 0] = 1
        self.id_tensor = torch.from_numpy(id_value).type_as(self.weight)
        nn.init.zeros_(self.weight)

    def forward(self, input):
        kernel = self.weight + self.id_tensor.to(self.weight.device)
        result = F.conv2d(input,
                          kernel,
                          None,
                          stride=1,
                          padding=0,
                          dilation=self.dilation,
                          groups=self.groups)
        return result

    def get_actual_kernel(self):
        return self.weight + self.id_tensor.to(self.weight.device)
class BNAndPadLayer(nn.Module):
    def __init__(self,
                 pad_pixels,
                 num_features,
                 eps=1e-5,
                 momentum=0.1,
                 affine=True,
                 track_running_stats=True):
        super(BNAndPadLayer, self).__init__()
        self.bn = nn.BatchNorm2d(num_features,
                                 eps,
                                 momentum,
                                 affine,
                                 track_running_stats)
        self.pad_pixels = pad_pixels

    def forward(self, input):
        output = self.bn(input)
        if self.pad_pixels > 0:
            if self.bn.affine:
                pad_values = self.bn.bias.detach() - self.bn.running_mean * self.bn.weight.detach() / torch.sqrt(
                    self.bn.running_var + self.bn.eps)
            else:
                pad_values = - self.bn.running_mean / torch.sqrt(self.bn.running_var + self.bn.eps)
            output = F.pad(output, [self.pad_pixels] * 4)
            pad_values = pad_values.view(1, -1, 1, 1)
            output[:, :, 0:self.pad_pixels, :] = pad_values
            output[:, :, -self.pad_pixels:, :] = pad_values
            output[:, :, :, 0:self.pad_pixels] = pad_values
            output[:, :, :, -self.pad_pixels:] = pad_values
        return output

    @property
    def weight(self):
        return self.bn.weight

    @property
    def bias(self):
        return self.bn.bias

    @property
    def running_mean(self):
        return self.bn.running_mean

    @property
    def running_var(self):
        return self.bn.running_var

    @property
    def eps(self):
        return self.bn.eps
class DiverseBranchBlock(nn.Module):
    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 internal_channels_1x1_3x3=None,
                 deploy=False,
                 nonlinear=None,
                 single_init=False):
        super(DiverseBranchBlock, self).__init__()
        self.deploy = deploy

        if nonlinear is None:
            self.nonlinear = nn.Identity()
        else:
            self.nonlinear = nonlinear

        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.groups = groups
        assert padding == kernel_size // 2

        if deploy:
            self.dbb_reparam = nn.Conv2d(in_channels=in_channels,
                                         out_channels=out_channels,
                                         kernel_size=kernel_size,
                                         stride=stride,
                                         padding=padding,
                                         dilation=dilation,
                                         groups=groups,
                                         bias=True)
        else:
            self.dbb_origin = conv_bn(in_channels=in_channels,
                                      out_channels=out_channels,
                                      kernel_size=kernel_size,
                                      stride=stride,
                                      padding=padding,
                                      dilation=dilation,
                                      groups=groups)

            self.dbb_avg = nn.Sequential()
            if groups < out_channels:
                self.dbb_avg.add_module('conv',
                                        nn.Conv2d(in_channels=in_channels,
                                                  out_channels=out_channels,
                                                  kernel_size=1,
                                                  stride=1,
                                                  padding=0,
                                                  groups=groups,
                                                  bias=False))
                self.dbb_avg.add_module('bn',
                                        BNAndPadLayer(pad_pixels=padding,
                                                      num_features=out_channels))
                self.dbb_avg.add_module('avg',
                                        nn.AvgPool2d(kernel_size=kernel_size,
                                                     stride=stride,
                                                     padding=0))
                self.dbb_1x1 = conv_bn(in_channels=in_channels,
                                       out_channels=out_channels,
                                       kernel_size=1,
                                       stride=stride,
                                       padding=0,
                                       groups=groups)
            else:
                self.dbb_avg.add_module('avg',
                                        nn.AvgPool2d(kernel_size=kernel_size,
                                                     stride=stride,
                                                     padding=padding))

            self.dbb_avg.add_module('avgbn',
                                    nn.BatchNorm2d(out_channels))

            if internal_channels_1x1_3x3 is None:
                # For mobilenet, it is better to have 2X internal channels
                internal_channels_1x1_3x3 = in_channels if groups < out_channels else 2 * in_channels

            self.dbb_1x1_kxk = nn.Sequential()
            if internal_channels_1x1_3x3 == in_channels:
                self.dbb_1x1_kxk.add_module('idconv1',
                                            IdentityBasedConv1x1(channels=in_channels, groups=groups))
            else:
                self.dbb_1x1_kxk.add_module('conv1',
                                            nn.Conv2d(in_channels=in_channels,
                                                      out_channels=internal_channels_1x1_3x3,
                                                      kernel_size=1,
                                                      stride=1,
                                                      padding=0,
                                                      groups=groups,
                                                      bias=False))
            self.dbb_1x1_kxk.add_module('bn1',
                                        BNAndPadLayer(pad_pixels=padding,
                                                      num_features=internal_channels_1x1_3x3, affine=True))
            self.dbb_1x1_kxk.add_module('conv2',
                                        nn.Conv2d(in_channels=internal_channels_1x1_3x3,
                                                  out_channels=out_channels,
                                                  kernel_size=kernel_size,
                                                  stride=stride,
                                                  padding=0,
                                                  groups=groups,
                                                  bias=False))
            self.dbb_1x1_kxk.add_module('bn2', nn.BatchNorm2d(out_channels))

        # The experiments reported in the paper used the default initialization of bn.weight (all as 1).
        # But changing the initialization may be useful in some cases.
        if single_init:
            # Initialize the bn.weight of dbb_origin as 1 and others as 0.
            # This is not the default setting.
            self.single_init()
    def get_equivalent_kernel_bias(self):
        # Transform I: fuse conv + BN for the original KxK branch.
        k_origin, b_origin = transI_fusebn(self.dbb_origin.conv.weight,
                                           self.dbb_origin.bn)

        if hasattr(self, 'dbb_1x1'):
            # Transform I: fuse conv + BN for the 1x1 branch.
            k_1x1, b_1x1 = transI_fusebn(self.dbb_1x1.conv.weight,
                                         self.dbb_1x1.bn)
            # Transform VI: zero-pad the 1x1 kernel to KxK (multi-scale merge).
            k_1x1 = transVI_multiscale(k_1x1,
                                       self.kernel_size)
        else:
            k_1x1, b_1x1 = 0, 0

        if hasattr(self.dbb_1x1_kxk, 'idconv1'):
            k_1x1_kxk_first = self.dbb_1x1_kxk.idconv1.get_actual_kernel()
        else:
            k_1x1_kxk_first = self.dbb_1x1_kxk.conv1.weight
        # Transform I: fuse conv + BN for the first (1x1) conv of the 1x1-KxK branch.
        k_1x1_kxk_first, b_1x1_kxk_first = transI_fusebn(k_1x1_kxk_first,
                                                         self.dbb_1x1_kxk.bn1)
        # Transform I: fuse conv + BN for the second (KxK) conv of the 1x1-KxK branch.
        k_1x1_kxk_second, b_1x1_kxk_second = transI_fusebn(self.dbb_1x1_kxk.conv2.weight,
                                                           self.dbb_1x1_kxk.bn2)
        # Transform III: merge the 1x1 conv and the KxK conv into a single KxK conv.
        k_1x1_kxk_merged, b_1x1_kxk_merged = transIII_1x1_kxk(k_1x1_kxk_first,
                                                              b_1x1_kxk_first,
                                                              k_1x1_kxk_second,
                                                              b_1x1_kxk_second,
                                                              groups=self.groups)

        # Transform V: express KxK average pooling as an equivalent KxK conv kernel.
        k_avg = transV_avg(self.out_channels, self.kernel_size, self.groups)
        # Transform I: fuse the pooling "conv" with its BN.
        k_1x1_avg_second, b_1x1_avg_second = transI_fusebn(k_avg.to(self.dbb_avg.avgbn.weight.device),
                                                           self.dbb_avg.avgbn)
        if hasattr(self.dbb_avg, 'conv'):
            # Transform I: fuse conv + BN for the 1x1 conv in the AVG branch.
            k_1x1_avg_first, b_1x1_avg_first = transI_fusebn(self.dbb_avg.conv.weight,
                                                             self.dbb_avg.bn)
            # Transform III: merge the 1x1 conv and the pooling "conv".
            k_1x1_avg_merged, b_1x1_avg_merged = transIII_1x1_kxk(k_1x1_avg_first,
                                                                  b_1x1_avg_first,
                                                                  k_1x1_avg_second,
                                                                  b_1x1_avg_second,
                                                                  groups=self.groups)
        else:
            k_1x1_avg_merged, b_1x1_avg_merged = k_1x1_avg_second, b_1x1_avg_second

        # Transform II: add up the kernels and biases of all branches.
        return transII_addbranch((k_origin,
                                  k_1x1,
                                  k_1x1_kxk_merged,
                                  k_1x1_avg_merged),
                                 (b_origin,
                                  b_1x1,
                                  b_1x1_kxk_merged,
                                  b_1x1_avg_merged))
    def switch_to_deploy(self):
        if hasattr(self, 'dbb_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.dbb_reparam = nn.Conv2d(in_channels=self.dbb_origin.conv.in_channels,
                                     out_channels=self.dbb_origin.conv.out_channels,
                                     kernel_size=self.dbb_origin.conv.kernel_size,
                                     stride=self.dbb_origin.conv.stride,
                                     padding=self.dbb_origin.conv.padding,
                                     dilation=self.dbb_origin.conv.dilation,
                                     groups=self.dbb_origin.conv.groups, bias=True)
        self.dbb_reparam.weight.data = kernel
        self.dbb_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('dbb_origin')
        self.__delattr__('dbb_avg')
        if hasattr(self, 'dbb_1x1'):
            self.__delattr__('dbb_1x1')
        self.__delattr__('dbb_1x1_kxk')

    def forward(self, inputs):
        if hasattr(self, 'dbb_reparam'):
            return self.nonlinear(self.dbb_reparam(inputs))
        out = self.dbb_origin(inputs)
        if hasattr(self, 'dbb_1x1'):
            out += self.dbb_1x1(inputs)
        out += self.dbb_avg(inputs)
        out += self.dbb_1x1_kxk(inputs)
        return self.nonlinear(out)

    def init_gamma(self, gamma_value):
        if hasattr(self, "dbb_origin"):
            torch.nn.init.constant_(self.dbb_origin.bn.weight,
                                    gamma_value)
        if hasattr(self, "dbb_1x1"):
            torch.nn.init.constant_(self.dbb_1x1.bn.weight,
                                    gamma_value)
        if hasattr(self, "dbb_avg"):
            torch.nn.init.constant_(self.dbb_avg.avgbn.weight,
                                    gamma_value)
        if hasattr(self, "dbb_1x1_kxk"):
            torch.nn.init.constant_(self.dbb_1x1_kxk.bn2.weight,
                                    gamma_value)

    def single_init(self):
        self.init_gamma(0.0)
        if hasattr(self, "dbb_origin"):
            torch.nn.init.constant_(self.dbb_origin.bn.weight, 1.0)
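The transforms imported from dbb_transforms at the top are not reproduced in this post; see dbb_transforms.py in the official repository (reference [1]) for the exact definitions, including the grouped-convolution cases. As a rough sketch of what they compute, assuming groups=1:

# Sketch for illustration only; the official dbb_transforms.py also handles groups > 1.
import torch
import torch.nn.functional as F

def transI_fusebn(kernel, bn):
    # Fold a BN layer into the preceding conv: scale the kernel by gamma/std,
    # and turn the BN shift into a conv bias (valid with BN running statistics).
    gamma = bn.weight
    std = (bn.running_var + bn.eps).sqrt()
    return kernel * (gamma / std).reshape(-1, 1, 1, 1), bn.bias - bn.running_mean * gamma / std

def transII_addbranch(kernels, biases):
    # Parallel branches with the same output shape are merged by summing kernels and biases.
    return sum(kernels), sum(biases)

def transIII_1x1_kxk(k1, b1, k2, b2, groups=1):
    # Merge a 1x1 conv followed by a KxK conv into a single KxK conv (groups=1 case).
    k = F.conv2d(k2, k1.permute(1, 0, 2, 3))
    b_hat = (k2 * b1.reshape(1, -1, 1, 1)).sum((1, 2, 3))
    return k, b_hat + b2

def transV_avg(channels, kernel_size, groups=1):
    # Express KxK average pooling as an equivalent KxK conv kernel.
    input_dim = channels // groups
    k = torch.zeros(channels, input_dim, kernel_size, kernel_size)
    for i in range(channels):
        k[i, i % input_dim, :, :] = 1.0 / kernel_size ** 2
    return k

def transVI_multiscale(kernel, target_kernel_size):
    # Zero-pad a smaller kernel (e.g. 1x1) up to the target KxK size.
    pad_h = (target_kernel_size - kernel.size(2)) // 2
    pad_w = (target_kernel_size - kernel.size(3)) // 2
    return F.pad(kernel, [pad_w, pad_w, pad_h, pad_h])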
Without further ado: export the model before and after switching to deploy mode, compare the ONNX outputs, and see for yourself how good this is!
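Even before exporting to ONNX, a quick in-framework check already shows the equivalence. A minimal sketch (channel counts and input size chosen arbitrarily):

# The training-time multi-branch block and the re-parameterized single conv
# should produce numerically identical outputs.
dbb = DiverseBranchBlock(in_channels=32, out_channels=64, kernel_size=3,
                         stride=1, padding=1, nonlinear=nn.ReLU())
dbb.eval()  # the fusion assumes BN uses its running statistics

x = torch.randn(1, 32, 56, 56)
y_train = dbb(x)          # training-time multi-branch structure

dbb.switch_to_deploy()    # fold everything into dbb_reparam
y_deploy = dbb(x)         # single KxK conv + bias

print(torch.allclose(y_train, y_deploy, atol=1e-4))  # expect True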
4. References
[1] https://github.com/DingXiaoH/DiverseBranchBlock/blob/main/diversebranchblock.py