in crlapi/sl/architectures/firefly_vgg/sp/conv.py [0:0]
def __init__(self,
             in_channels,
             out_channels,
             kernel_size,
             stride=1,
             padding=1,
             groups=1,
             can_split=True,
             bias=True,
             actv_fn='relu',
             has_bn=False,
             rescale=1.0):
    """Build the layer: an ``nn.Conv2d`` stored in ``self.module``.

    Also records the convolution geometry as per-axis attributes:
    ``self.kh``/``self.kw`` (kernel), ``self.ph``/``self.pw`` (padding),
    ``self.dh``/``self.dw`` (stride) and ``self.groups``.

    Args:
        in_channels: Input channel count forwarded to ``nn.Conv2d``.
        out_channels: Output channel count forwarded to ``nn.Conv2d``; also
            the feature count of the optional ``nn.BatchNorm2d``.
        kernel_size: An int (same on both axes) or a length-2 sequence.
        stride: An int (same on both axes) or a length-2 sequence.
        padding: An int (same on both axes) or a length-2 sequence.
        groups: Forwarded to ``nn.Conv2d`` and stored as ``self.groups``.
        can_split: Forwarded unchanged to the base class.
        bias: Forwarded to the base class as ``has_bias``; ignored for the
            convolution when ``has_bn`` is true.
        actv_fn: Forwarded unchanged to the base class.
        has_bn: If true, create ``self.bn = nn.BatchNorm2d(out_channels)``
            and build the convolution without a bias.
        rescale: Forwarded unchanged to the base class.

    Raises:
        ValueError: If ``kernel_size``, ``stride`` or ``padding`` is a sized
            value that is not an int and does not have exactly two elements.
        TypeError: If one of those arguments is neither an int nor sized.
    """
    def _as_pair(value, name):
        # Expand an int to (value, value); otherwise require exactly two items.
        # An explicit exception replaces the former `assert`, which is
        # removed when Python runs with -O.
        if isinstance(value, int):
            return value, value
        if len(value) != 2:
            raise ValueError(
                '{} must be an int or a pair, got {!r}'.format(name, value))
        first, second = value
        return first, second

    super().__init__(can_split=can_split,
                     actv_fn=actv_fn,
                     has_bn=has_bn,
                     has_bias=bias,
                     rescale=rescale)
    if has_bn:
        self.bn = nn.BatchNorm2d(out_channels)
        # The conv below is built with bias=self.has_bias, so this drops the
        # conv bias; presumably because BatchNorm2d's affine shift already
        # provides one.
        self.has_bias = False
    self.kh, self.kw = _as_pair(kernel_size, 'kernel_size')
    self.ph, self.pw = _as_pair(padding, 'padding')
    self.dh, self.dw = _as_pair(stride, 'stride')
    self.groups = groups
    # The caller's original kernel_size/stride/padding values are passed
    # through as given; nn.Conv2d accepts either ints or pairs.
    # NOTE(review): when has_bn is false, self.has_bias comes from the base
    # class (presumably set from has_bias=bias) — confirm.
    self.module = nn.Conv2d(in_channels=in_channels,
                            out_channels=out_channels,
                            kernel_size=kernel_size,
                            groups=groups,
                            stride=stride,
                            padding=padding,
                            bias=self.has_bias)