network/resnet101_3d_gcn_x5.py [14:73]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
import torch.nn as nn  # required by the building blocks below

# Support both package-relative imports and direct execution as a script.
try:
    from . import initializer
    from .global_reasoning_unit import GloRe_Unit
except ImportError:
    import initializer
    from global_reasoning_unit import GloRe_Unit

# Pre-activation 2D unit: BatchNorm -> ReLU -> Conv2d.
class BN_AC_CONV2D(nn.Module):

    def __init__(self, num_in, num_filter,
                 kernel=(1,1), pad=(0,0), stride=(1,1), g=1, bias=False):
        super(BN_AC_CONV2D, self).__init__()
        self.bn = nn.BatchNorm2d(num_in, eps=1e-04)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(num_in, num_filter, kernel_size=kernel, padding=pad,
                               stride=stride, groups=g, bias=bias)

    def forward(self, x):
        h = self.relu(self.bn(x))
        h = self.conv(h)
        return h

# Pre-activation 3D unit: BatchNorm -> ReLU -> Conv3d.
class BN_AC_CONV3D(nn.Module):

    def __init__(self, num_in, num_filter,
                 kernel=(1,1,1), pad=(0,0,0), stride=(1,1,1), g=1, bias=False):
        super(BN_AC_CONV3D, self).__init__()
        self.bn = nn.BatchNorm3d(num_in, eps=1e-04)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv3d(num_in, num_filter, kernel_size=kernel, padding=pad,
                               stride=stride, groups=g, bias=bias)

    def forward(self, x):
        h = self.relu(self.bn(x))
        h = self.conv(h)
        return h

# Pre-activation bottleneck: temporal conv -> spatial conv -> 1x1x1 projection,
# wrapped by a residual (identity) connection.
class RESIDUAL_BLOCK(nn.Module):

    def __init__(self, num_in, num_mid, num_out, g=1, stride=(1,1,1), first_block=False, use_3d=True):
        super(RESIDUAL_BLOCK, self).__init__()
        # temporal kernel/padding: 3/1 when mixing neighbouring frames (3D), 1/0 for a 2D-style block
        kt, pt = (3,1) if use_3d else (1,0)

        self.conv_m1 = BN_AC_CONV3D(num_in=num_in, num_filter=num_mid, kernel=(kt,1,1), pad=(pt,0,0))
        self.conv_m2 = BN_AC_CONV3D(num_in=num_mid, num_filter=num_mid, kernel=(1,3,3), pad=(0,1,1), stride=stride, g=g)
        self.conv_m3 = BN_AC_CONV3D(num_in=num_mid, num_filter=num_out, kernel=(1,1,1), pad=(0,0,0))
        # adapter: project the identity path when the first block of a stage
        # changes the channel count and/or applies a spatial stride
        if first_block:
            self.conv_w1 = BN_AC_CONV3D(num_in=num_in, num_filter=num_out, kernel=(1,1,1), pad=(0,0,0), stride=stride)

    def forward(self, x):
        # main (bottleneck) branch
        h = self.conv_m1(x)
        h = self.conv_m2(h)
        h = self.conv_m3(h)

        # identity branch: pass through the adapter when it exists
        if hasattr(self, 'conv_w1'):
            x = self.conv_w1(x)

        return h + x
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
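
A minimal usage sketch of RESIDUAL_BLOCK from the file above, assuming torch is installed and the classes are in scope; the channel counts, stride, and clip size are illustrative and not taken from the repository:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
import torch

# First block of a stage: channels grow (64 -> 256) and space is downsampled,
# so first_block=True attaches the 1x1x1 adapter (conv_w1) to the identity path.
stage_entry = RESIDUAL_BLOCK(num_in=64, num_mid=64, num_out=256,
                             stride=(1, 2, 2), first_block=True)
# Subsequent blocks keep the shape, so the input is added back unchanged.
stage_body = RESIDUAL_BLOCK(num_in=256, num_mid=64, num_out=256)

x = torch.randn(2, 64, 8, 56, 56)   # (N, C, T, H, W)
h = stage_entry(x)                   # -> (2, 256, 8, 28, 28)
h = stage_body(h)                    # shape preserved
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -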



network/resnet50_3d_gcn_x5.py [14:73]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
import torch.nn as nn  # required by the building blocks below

# Support both package-relative imports and direct execution as a script.
try:
    from . import initializer
    from .global_reasoning_unit import GloRe_Unit
except ImportError:
    import initializer
    from global_reasoning_unit import GloRe_Unit

# Pre-activation 2D unit: BatchNorm -> ReLU -> Conv2d.
class BN_AC_CONV2D(nn.Module):

    def __init__(self, num_in, num_filter,
                 kernel=(1,1), pad=(0,0), stride=(1,1), g=1, bias=False):
        super(BN_AC_CONV2D, self).__init__()
        self.bn = nn.BatchNorm2d(num_in, eps=1e-04)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(num_in, num_filter, kernel_size=kernel, padding=pad,
                               stride=stride, groups=g, bias=bias)

    def forward(self, x):
        h = self.relu(self.bn(x))
        h = self.conv(h)
        return h

# Pre-activation 3D unit: BatchNorm -> ReLU -> Conv3d.
class BN_AC_CONV3D(nn.Module):

    def __init__(self, num_in, num_filter,
                 kernel=(1,1,1), pad=(0,0,0), stride=(1,1,1), g=1, bias=False):
        super(BN_AC_CONV3D, self).__init__()
        self.bn = nn.BatchNorm3d(num_in, eps=1e-04)
        self.relu = nn.ReLU(inplace=True)
        self.conv = nn.Conv3d(num_in, num_filter, kernel_size=kernel, padding=pad,
                               stride=stride, groups=g, bias=bias)

    def forward(self, x):
        h = self.relu(self.bn(x))
        h = self.conv(h)
        return h

# Pre-activation bottleneck: temporal conv -> spatial conv -> 1x1x1 projection,
# wrapped by a residual (identity) connection.
class RESIDUAL_BLOCK(nn.Module):

    def __init__(self, num_in, num_mid, num_out, g=1, stride=(1,1,1), first_block=False, use_3d=True):
        super(RESIDUAL_BLOCK, self).__init__()
        # temporal kernel/padding: 3/1 when mixing neighbouring frames (3D), 1/0 for a 2D-style block
        kt, pt = (3,1) if use_3d else (1,0)

        self.conv_m1 = BN_AC_CONV3D(num_in=num_in, num_filter=num_mid, kernel=(kt,1,1), pad=(pt,0,0))
        self.conv_m2 = BN_AC_CONV3D(num_in=num_mid, num_filter=num_mid, kernel=(1,3,3), pad=(0,1,1), stride=stride, g=g)
        self.conv_m3 = BN_AC_CONV3D(num_in=num_mid, num_filter=num_out, kernel=(1,1,1), pad=(0,0,0))
        # adapter: project the identity path when the first block of a stage
        # changes the channel count and/or applies a spatial stride
        if first_block:
            self.conv_w1 = BN_AC_CONV3D(num_in=num_in, num_filter=num_out, kernel=(1,1,1), pad=(0,0,0), stride=stride)

    def forward(self, x):
        # main (bottleneck) branch
        h = self.conv_m1(x)
        h = self.conv_m2(h)
        h = self.conv_m3(h)

        # identity branch: pass through the adapter when it exists
        if hasattr(self, 'conv_w1'):
            x = self.conv_w1(x)

        return h + x
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
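
The only difference the use_3d flag makes is the temporal extent of conv_m1; a short sketch (illustrative sizes, torch assumed, classes above in scope):
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
import torch

# 3D variant: conv_m1 uses a (3,1,1) kernel with padding (1,0,0),
# so each output frame mixes information from its temporal neighbours.
block_3d = RESIDUAL_BLOCK(num_in=32, num_mid=16, num_out=32, use_3d=True)
# 2D-style variant: conv_m1 collapses to a (1,1,1) kernel,
# so no temporal mixing happens inside this block.
block_2d = RESIDUAL_BLOCK(num_in=32, num_mid=16, num_out=32, use_3d=False)

x = torch.randn(1, 32, 4, 14, 14)
print(block_3d(x).shape)   # torch.Size([1, 32, 4, 14, 14])
print(block_2d(x).shape)   # same shape; only the temporal receptive field differs
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -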



