def __init__()

in downstream/votenet/lib/ddp_trainer.py [0:0]


    def __init__(self, config):
        """Build datasets, dataloaders, detector, optimizer and BN scheduler.

        Args:
            config: Configuration object with ``data``, ``net``, ``optimizer``,
                ``test`` and ``misc`` sections read by attribute access. If a
                ``config.yaml`` file exists in the current working directory,
                it is loaded with ``OmegaConf.load`` and replaces this
                argument entirely.

        Raises:
            SystemExit: If ``config.data.dataset`` is neither ``'sunrgbd'``
                nor ``'scannet'``.
        """
        world_size = get_world_size()
        distributed = world_size > 1
        self.is_master = is_master_proc(world_size) if distributed else True
        self.cur_device = torch.cuda.current_device()

        # Load the configuration. A config.yaml found in the working directory
        # takes precedence over the config passed by the caller.
        self.setup_logging()
        if os.path.exists('config.yaml'):
            logging.info('===> Loading exsiting config file')
            config = OmegaConf.load('config.yaml')
            logging.info('===> Loaded exsiting config file')
        logging.info('===> Configurations')
        # NOTE(review): `pretty()` is deprecated in newer OmegaConf releases in
        # favour of `OmegaConf.to_yaml(config)` -- confirm the pinned version.
        logging.info(config.pretty())

        # Create the train/test datasets. Keyword arguments shared by both
        # splits and both dataset families are collected once.
        shared_kwargs = dict(
            num_points=config.data.num_points,
            use_color=config.data.use_color,
            use_height=(not config.data.no_height),
        )
        if config.data.dataset == 'sunrgbd':
            from datasets.sunrgbd.sunrgbd_detection_dataset import SunrgbdDetectionVotesDataset
            from datasets.sunrgbd.model_util_sunrgbd import SunrgbdDatasetConfig
            dataset_config = SunrgbdDatasetConfig()
            use_v1 = not config.data.use_sunrgbd_v2
            train_dataset = SunrgbdDetectionVotesDataset(
                'train', augment=True, use_v1=use_v1, **shared_kwargs)
            test_dataset = SunrgbdDetectionVotesDataset(
                config.test.phase, augment=False, use_v1=use_v1, **shared_kwargs)
        elif config.data.dataset == 'scannet':
            from datasets.scannet.scannet_detection_dataset import ScannetDetectionDataset
            from datasets.scannet.model_util_scannet import ScannetDatasetConfig
            dataset_config = ScannetDatasetConfig()
            # Only the training split receives the by_scenes / by_points options.
            train_dataset = ScannetDetectionDataset(
                'train',
                augment=True,
                by_scenes=config.data.by_scenes,
                by_points=config.data.by_points,
                **shared_kwargs)
            test_dataset = ScannetDetectionDataset(
                config.test.phase, augment=False, **shared_kwargs)
        else:
            logging.error('Unknown dataset %s. Exiting...', config.data.dataset)
            # Raise SystemExit directly instead of calling the `exit` helper
            # injected by the `site` module, which is not always available.
            raise SystemExit(-1)

        # Optionally wrap both datasets for the sparse-conv backbone; the
        # voxelized datasets come with their own collate function.
        collate = None
        if config.data.voxelization:
            from models.backbone.sparseconv.voxelized_dataset import VoxelizationDataset, collate_fn
            train_dataset = VoxelizationDataset(train_dataset, config.data.voxel_size)
            test_dataset = VoxelizationDataset(test_dataset, config.data.voxel_size)
            collate = collate_fn
        logging.info('training: {}, testing: {}'.format(len(train_dataset), len(test_dataset)))

        # With several processes each one gets a shard of the training set via
        # DistributedSampler; DataLoader forbids shuffle=True with a sampler.
        self.sampler = (
            torch.utils.data.distributed.DistributedSampler(train_dataset)
            if distributed else None
        )
        train_dataloader = DataLoader(
            train_dataset,
            # config.data.batch_size is the global batch size, split per GPU.
            batch_size=config.data.batch_size // config.misc.num_gpus,
            shuffle=(self.sampler is None),
            sampler=self.sampler,
            num_workers=config.data.num_workers,
            collate_fn=collate)

        test_dataloader = DataLoader(
            test_dataset,
            batch_size=1,
            shuffle=False,
            num_workers=1,
            collate_fn=collate)
        logging.info('train dataloader: {}, test dataloader: {}'.format(len(train_dataloader),len(test_dataloader)))

        # Init the model and optimizer
        model_module = importlib.import_module('models.' + config.net.model)  # import network module
        # 3 channels for color (if used) plus 1 for height (if used).
        num_input_channel = int(config.data.use_color)*3 + int(not config.data.no_height)*1

        detector_cls = model_module.BoxNet if config.net.model == 'boxnet' else model_module.VoteNet

        net = detector_cls(num_class=dataset_config.num_class,
                           num_heading_bin=dataset_config.num_heading_bin,
                           num_size_cluster=dataset_config.num_size_cluster,
                           mean_size_arr=dataset_config.mean_size_arr,
                           num_proposal=config.net.num_target,
                           input_feature_dim=num_input_channel,
                           vote_factor=config.net.vote_factor,
                           sampling=config.net.cluster_sampling,
                           backbone=config.net.backbone)

        if config.net.weights != '':
            print('===> Loading weights: ' + config.net.weights)
            state = torch.load(config.net.weights, map_location=lambda s, l: default_restore_location(s, 'cpu'))
            # When training, the weights go into the backbone only (its inner
            # `.net` for the sparseconv backbone); otherwise into the whole
            # detector.
            target = net
            if config.net.is_train:
                target = net.backbone_net
                if config.net.backbone == "sparseconv":
                    target = net.backbone_net.net

            # Presumably keeps only entries whose shapes match the target
            # module -- see DetectionTrainer.load_state_with_same_shape.
            matched_weights = DetectionTrainer.load_state_with_same_shape(target, state['state_dict'])
            merged_state = target.state_dict()
            merged_state.update(matched_weights)
            target.load_state_dict(merged_state)

        net.to(self.cur_device)
        if distributed:
            net = torch.nn.parallel.DistributedDataParallel(
                module=net,
                device_ids=[self.cur_device],
                output_device=self.cur_device,
                broadcast_buffers=False)

        # Load the Adam optimizer
        self.optimizer = optim.Adam(
            net.parameters(),
            lr=config.optimizer.learning_rate,
            weight_decay=config.optimizer.weight_decay)
        # Only the master process gets a tensorboard writer; `self.writer` is
        # not defined on the other processes.
        if self.is_master:
            self.writer = SummaryWriter(log_dir='tensorboard')
        self.config = config
        self.dataset_config = dataset_config
        self.net = net
        self.train_dataloader = train_dataloader
        self.test_dataloader = test_dataloader
        self.best_mAP = -1

        # Used for AP calculation (fixed settings)
        self.CONFIG_DICT = {'remove_empty_box':False, 'use_3d_nms':True,
                            'nms_iou':0.25, 'use_old_type_nms':False, 'cls_nms':True,
                            'per_class_proposal': True, 'conf_thresh':0.05, 'dataset_config': dataset_config}

        # Used for AP calculation (settings taken from config.test)
        self.CONFIG_DICT_TEST = {'remove_empty_box': (not config.test.faster_eval),
                                 'use_3d_nms': config.test.use_3d_nms,
                                 'nms_iou': config.test.nms_iou,
                                 'use_old_type_nms': config.test.use_old_type_nms,
                                 'cls_nms': config.test.use_cls_nms,
                                 'per_class_proposal': config.test.per_class_proposal,
                                 'conf_thresh': config.test.conf_thresh,
                                 'dataset_config': dataset_config}

        # Load checkpoint if there is any
        self.start_epoch = 0
        checkpoint_path = 'checkpoint.tar'
        if os.path.isfile(checkpoint_path):
            # Deserialize onto the CPU: without map_location every process
            # would restore the tensors onto the GPU they were saved from.
            # load_state_dict() below copies them to this process's device.
            checkpoint = torch.load(checkpoint_path, map_location='cpu')
            # Under DDP the real model sits behind `.module`.
            _model = self.net.module if distributed else self.net
            _model.load_state_dict(checkpoint['state_dict'])
            self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
            self.start_epoch = checkpoint['epoch']
            self.best_mAP = checkpoint['best_mAP']
            logging.info("-> loaded checkpoint %s (epoch: %d)"%(checkpoint_path, self.start_epoch))

        # Decay the BatchNorm momentum from 0.5 down to a floor of 0.001.
        # note: pytorch's BN momentum (default 0.1) = 1 - tensorflow's BN momentum,
        # so this is 0.5 -> 0.999 in the tensorflow convention.
        bn_momentum_init = 0.5
        bn_momentum_floor = 0.001
        bn_decay_step = config.optimizer.bn_decay_step
        bn_decay_rate = config.optimizer.bn_decay_rate

        def bn_momentum_at(it):
            """Return the BN momentum for scheduler step ``it``."""
            decayed = bn_momentum_init * bn_decay_rate**(int(it / bn_decay_step))
            return max(decayed, bn_momentum_floor)

        self.bnm_scheduler = BNMomentumScheduler(net, bn_lambda=bn_momentum_at, last_epoch=self.start_epoch-1)