def __init__()

in crlapi/sl/architectures/firefly_vgg/sp/conv.py [0:0]


    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=1,
                 groups=1,
                 can_split=True,
                 bias=True,
                 actv_fn='relu',
                 has_bn=False,
                 rescale=1.0):
        """Build a 2-D convolution layer that wraps ``nn.Conv2d``.

        Args:
            in_channels: Number of input channels, passed to ``nn.Conv2d``.
            out_channels: Number of output channels, passed to ``nn.Conv2d``.
                Also sizes the optional ``nn.BatchNorm2d``.
            kernel_size: An int, or a length-2 sequence. Its entries are
                stored as ``self.kh`` / ``self.kw``.
            stride: An int, or a length-2 sequence. Its entries are stored as
                ``self.dh`` / ``self.dw``.
            padding: An int, or a length-2 sequence. Its entries are stored
                as ``self.ph`` / ``self.pw``.
            groups: Stored as ``self.groups`` and passed to ``nn.Conv2d``.
            can_split, actv_fn, rescale: Forwarded unchanged to the
                base-class constructor. Their meaning is defined there.
            bias: Forwarded to the base class as ``has_bias``. The conv is
                then built with ``bias=self.has_bias``.
            has_bn: Forwarded to the base class. If true, also creates
                ``self.bn = nn.BatchNorm2d(out_channels)`` and forces
                ``self.has_bias`` to False, so the conv is built without bias.

        Raises:
            AssertionError: A non-int ``kernel_size``, ``padding`` or
                ``stride`` does not have exactly two entries. This only
                happens while assertions are enabled; under ``python -O`` a
                wrong length surfaces as a ``ValueError`` from tuple
                unpacking instead.
        """
        super().__init__(
            can_split=can_split,
            actv_fn=actv_fn,
            has_bn=has_bn,
            has_bias=bias,
            rescale=rescale,
        )

        if has_bn:
            # NOTE(review): presumably the BN runs right after the conv.
            # forward() is not visible here, so confirm. If so, a conv bias
            # would be redundant, which is why the bias flag is cleared.
            self.bn = nn.BatchNorm2d(out_channels)
            self.has_bias = False

        def as_hw_pair(spec):
            # An int is used for both spatial dimensions.
            # Anything else must hold exactly two entries.
            if isinstance(spec, int):
                return spec, spec
            assert len(spec) == 2
            height, width = spec
            return height, width

        self.kh, self.kw = as_hw_pair(kernel_size)
        self.ph, self.pw = as_hw_pair(padding)
        self.dh, self.dw = as_hw_pair(stride)
        self.groups = groups

        # Hand nn.Conv2d the caller's original (un-normalized) arguments.
        # Conv2d does its own int-or-pair handling.
        self.module = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=self.has_bias,
        )