network/resnet101_3d_gcn_x5.py [175:212]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            logging.info("Network:: symbol initialized, use pretrained model: `{}'".format(pretrained_model))
            assert os.path.exists(pretrained_model), "cannot locate: `{}'".format(pretrained_model)
            state_dict_2d = torch.load(pretrained_model)
            initializer.init_3d_from_2d_dict(net=self, state_dict=state_dict_2d, method=load_method)
        else:
            logging.info("Network:: symbol initialized, use random inilization!")

        blocker_name_list = []
        for name, param in self.state_dict().items():
            if name.endswith('blocker.weight'):
                blocker_name_list.append(name)
                param[:] = 0.
        if len(blocker_name_list) > 0:
            logging.info("Network:: change params of the following layer be zeros: {}".format(blocker_name_list))


    def forward(self, x):
        """Run ``x`` through every backbone stage and return classifier output.

        Args:
            x: input tensor whose size along dim 2 must be exactly 8
               (enforced by the assert below). Presumably a 5-D video clip
               laid out as (batch, channels, time, height, width) — TODO
               confirm against the data loader.

        Returns:
            Whatever ``self.classifier`` produces for the pooled features,
            which are flattened to (batch, -1) before being passed in.
        """
        assert x.shape[2] == 8

        # Stages applied strictly in this order. The size notes are the
        # author's original annotations (assumes a 112x112 input):
        #   conv1: x112 -> x56   conv2: x56 -> x56   conv3: x56 -> x28
        #   conv4: x28  -> x14   conv5: x14 -> x7
        # followed by the tail block and global pooling.
        feats = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4,
                      self.conv5, self.tail, self.globalpool):
            feats = stage(feats)

        # Keep the batch dimension and flatten everything else for the classifier.
        batch_size = feats.shape[0]
        return self.classifier(feats.view(batch_size, -1))

if __name__ == "__main__":
    import torch
    logging.getLogger().setLevel(logging.DEBUG)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



network/resnet50_3d_gcn_x5.py [175:212]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            logging.info("Network:: symbol initialized, use pretrained model: `{}'".format(pretrained_model))
            assert os.path.exists(pretrained_model), "cannot locate: `{}'".format(pretrained_model)
            state_dict_2d = torch.load(pretrained_model)
            initializer.init_3d_from_2d_dict(net=self, state_dict=state_dict_2d, method=load_method)
        else:
            logging.info("Network:: symbol initialized, use random inilization!")

        blocker_name_list = []
        for name, param in self.state_dict().items():
            if name.endswith('blocker.weight'):
                blocker_name_list.append(name)
                param[:] = 0. 
        if len(blocker_name_list) > 0:
            logging.info("Network:: change params of the following layer be zeros: {}".format(blocker_name_list))


    def forward(self, x):
        """Run ``x`` through every backbone stage and return classifier output.

        Args:
            x: input tensor whose size along dim 2 must be exactly 8
               (enforced by the assert below). Presumably a 5-D video clip
               laid out as (batch, channels, time, height, width) — TODO
               confirm against the data loader.

        Returns:
            Whatever ``self.classifier`` produces for the pooled features,
            which are flattened to (batch, -1) before being passed in.
        """
        assert x.shape[2] == 8

        # Stages applied strictly in this order. The size notes are the
        # author's original annotations (assumes a 112x112 input):
        #   conv1: x112 -> x56   conv2: x56 -> x56   conv3: x56 -> x28
        #   conv4: x28  -> x14   conv5: x14 -> x7
        # followed by the tail block and global pooling.
        feats = x
        for stage in (self.conv1, self.conv2, self.conv3, self.conv4,
                      self.conv5, self.tail, self.globalpool):
            feats = stage(feats)

        # Keep the batch dimension and flatten everything else for the classifier.
        batch_size = feats.shape[0]
        return self.classifier(feats.view(batch_size, -1))

if __name__ == "__main__":
    import torch
    logging.getLogger().setLevel(logging.DEBUG)
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



