def __init__()

in network/resnet50_3d_gcn_x5.py [0:0]


    def __init__(self, num_classes, pretrained=False, **kwargs):
        super(RESNET50_3D_GCN_X5, self).__init__()

        groups = 1
        k_sec  = {  2: 3,
                    3: 4,
                    4: 6,
                    5: 3  }

        # conv1 - x112 (x16)
        conv1_num_out = 32
        self.conv1 = nn.Sequential(OrderedDict([
                    ('conv', nn.Conv3d( 3, conv1_num_out, kernel_size=(3,5,5), padding=(1,2,2), stride=(1,2,2), bias=False)),
                    ('bn', nn.BatchNorm3d(conv1_num_out, eps=1e-04)),
                    ('relu', nn.ReLU(inplace=True)),
                    ('max_pool', nn.MaxPool3d(kernel_size=(1,3,3), padding=(0,1,1), stride=(1,2,2))),
                    ]))

        # conv2 - x56 (x16)
        num_mid = 64
        conv2_num_out = 256
        self.conv2 = nn.Sequential(OrderedDict([
                    ("B%02d"%i, RESIDUAL_BLOCK(num_in=conv1_num_out if i==1 else conv2_num_out,
                                               num_mid=num_mid,
                                               num_out=conv2_num_out,
                                               stride=(1,1,1),  # conv2 applies no spatial or temporal downsampling
                                               g=groups,
                                               first_block=(i==1))) for i in range(1,k_sec[2]+1)
                    ]))

        # conv3 - x28 (x8)
        num_mid *= 2
        conv3_num_out = 2 * conv2_num_out
        blocks = []
        for i in range(1,k_sec[3]+1):
            use_3d = bool(i % 2)
            blocks.append(("B%02d"%i, RESIDUAL_BLOCK(num_in=conv2_num_out if i==1 else conv3_num_out,
                                                     num_mid=num_mid,
                                                     num_out=conv3_num_out,
                                                     stride=(2,2,2) if i==1 else (1,1,1),
                                                     use_3d=use_3d,
                                                     g=groups,
                                                     first_block=(i==1))))
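            # insert a GloRe unit after the 1st and 3rd residual blocks of conv3;
            # together with the three added in conv4 below, five GloRe units are
            # inserted in total (the "_X5" in the class name)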
            if i in [1,3]:
                blocks.append(("B%02d_extra"%i, GloRe_Unit(num_in=conv3_num_out, num_mid=num_mid)))
        self.conv3 = nn.Sequential(OrderedDict(blocks))
        
        # conv4 - x14 (x8)
        num_mid *= 2
        conv4_num_out = 2 * conv3_num_out
        blocks = []
        for i in range(1,k_sec[4]+1):
            use_3d = bool(i % 2)
            blocks.append(("B%02d"%i, RESIDUAL_BLOCK(num_in=conv3_num_out if i==1 else conv4_num_out,
                                                     num_mid=num_mid,
                                                     num_out=conv4_num_out,
                                                     stride=(1,2,2) if i==1 else (1,1,1),
                                                     use_3d=use_3d,
                                                     g=groups,
                                                     first_block=(i==1))))
            if i in [1,3,5]:
                blocks.append(("B%02d_extra"%i, GloRe_Unit(num_in=conv4_num_out, num_mid=num_mid)))
        self.conv4 = nn.Sequential(OrderedDict(blocks))
        
        # conv5 - x7 (x4)
        num_mid *= 2
        conv5_num_out = 2 * conv4_num_out
        self.conv5 = nn.Sequential(OrderedDict([
                    ("B%02d"%i, RESIDUAL_BLOCK(num_in=conv4_num_out if i==1 else conv5_num_out,
                                               num_mid=num_mid,
                                               num_out=conv5_num_out,
                                               stride=(1,2,2) if i==1 else (1,1,1),
                                               g=groups,
                                               use_3d=(i==2),
                                               first_block=(i==1))) for i in range(1,k_sec[5]+1)
                    ]))

        # final
        self.tail = nn.Sequential(OrderedDict([
                    ('bn', nn.BatchNorm3d(conv5_num_out, eps=1e-04)),
                    ('relu', nn.ReLU(inplace=True))
                    ]))

        self.globalpool = nn.Sequential(OrderedDict([
                        ('avg', nn.AvgPool3d(kernel_size=(4,7,7),  stride=(1,1,1))),
                        ('dropout', nn.Dropout(p=0.5)),
                        ]))
        self.classifier = nn.Linear(conv5_num_out, num_classes)

        #############
        # Initialization
        initializer.xavier(net=self)

        if pretrained:
            import torch
            load_method = 'inflation'  # options: 'random', 'inflation'
            pretrained_model = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'pretrained/resnet50-lite.pth')
            logging.info("Network:: symbol initialized, use pretrained model: `{}'".format(pretrained_model))
            assert os.path.exists(pretrained_model), "cannot locate: `{}'".format(pretrained_model)
            state_dict_2d = torch.load(pretrained_model)
            initializer.init_3d_from_2d_dict(net=self, state_dict=state_dict_2d, method=load_method)
        else:
            logging.info("Network:: symbol initialized, use random inilization!")

        blocker_name_list = []
        for name, param in self.state_dict().items():
            if name.endswith('blocker.weight'):
                blocker_name_list.append(name)
                param[:] = 0. 
        if len(blocker_name_list) > 0:
            logging.info("Network:: change params of the following layer be zeros: {}".format(blocker_name_list))