def main()

in downstream/votenet_det_new/ddp_main.py [0:0]


def main(config):
    # Load the configuration, preferring an existing config.yaml if present
    setup_logging()
    if os.path.exists('config.yaml'):
        logging.info('===> Loading existing config file')
        config = OmegaConf.load('config.yaml')
        logging.info('===> Loaded existing config file')
    logging.info(config.pretty())

    # Create Dataset and Dataloader
    if config.data.dataset == 'sunrgbd':
        from lib.datasets.sunrgbd.sunrgbd_detection_dataset import SunrgbdDetectionVotesDataset, MAX_NUM_OBJ
        from lib.datasets.sunrgbd.model_util_sunrgbd import SunrgbdDatasetConfig
        dataset_config = SunrgbdDatasetConfig()
        train_dataset = SunrgbdDetectionVotesDataset('train', 
            num_points=config.data.num_points,
            augment=True,
            use_color=config.data.use_color, 
            use_height=(not config.data.no_height),
            use_v1=(not config.data.use_sunrgbd_v2),
            data_ratio=config.data.data_ratio)
        test_dataset = SunrgbdDetectionVotesDataset('val', 
            num_points=config.data.num_points,
            augment=False,
            use_color=config.data.use_color, 
            use_height=(not config.data.no_height),
            use_v1=(not config.data.use_sunrgbd_v2))
    elif config.data.dataset == 'scannet':
        from lib.datasets.scannet.scannet_detection_dataset import ScannetDetectionDataset, MAX_NUM_OBJ
        from lib.datasets.scannet.model_util_scannet import ScannetDatasetConfig
        dataset_config = ScannetDatasetConfig()
        train_dataset = ScannetDetectionDataset('train', 
            num_points=config.data.num_points,
            augment=True,
            use_color=config.data.use_color, 
            use_height=(not config.data.no_height),
            data_ratio=config.data.data_ratio)
        test_dataset = ScannetDetectionDataset('val', 
            num_points=config.data.num_points,
            augment=False,
            use_color=config.data.use_color, 
            use_height=(not config.data.no_height))
    else:
        logging.error('Unknown dataset %s. Exiting...' % config.data.dataset)
        exit(-1)

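    # Optionally wrap the datasets for the sparse-conv backbone: voxelize the
    # point clouds and use the matching collate function for batching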
    COLLATE_FN = None
    if config.data.voxelization:
        from models.backbone.sparseconv.voxelized_dataset import VoxelizationDataset, collate_fn
        train_dataset = VoxelizationDataset(train_dataset, config.data.voxel_size)
        test_dataset = VoxelizationDataset(test_dataset, config.data.voxel_size)
        COLLATE_FN = collate_fn

    logging.info('training: {}, testing: {}'.format(len(train_dataset), len(test_dataset)))

    train_dataloader = DataLoader(
        train_dataset, 
        batch_size=config.data.batch_size,
        shuffle=True, 
        num_workers=config.data.num_workers, 
        worker_init_fn=my_worker_init_fn,
        collate_fn=COLLATE_FN)

    test_dataloader = DataLoader(
        test_dataset, 
        batch_size=config.data.batch_size,
        shuffle=True, 
        num_workers=config.data.num_workers, 
        worker_init_fn=my_worker_init_fn,
        collate_fn=COLLATE_FN)

    logging.info('train dataloader: {}, test dataloader: {}'.format(len(train_dataloader),len(test_dataloader)))

    # Init the model and optimizer
    MODEL = importlib.import_module('models.' + config.net.model) # import network module
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
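    # Per-point input features beyond XYZ: RGB adds 3 channels, height adds 1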
    num_input_channel = int(config.data.use_color)*3 + int(not config.data.no_height)*1

    if config.net.model == 'boxnet':
        Detector = MODEL.BoxNet
    else:
        Detector = MODEL.VoteNet

    net = Detector(num_class=dataset_config.num_class,
                   num_heading_bin=dataset_config.num_heading_bin,
                   num_size_cluster=dataset_config.num_size_cluster,
                   mean_size_arr=dataset_config.mean_size_arr,
                   num_proposal=config.net.num_target,
                   input_feature_dim=num_input_channel,
                   vote_factor=config.net.vote_factor,
                   sampling=config.net.cluster_sampling,
                   backbone=config.net.backbone)

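    # Optionally restore pre-trained weights (backbone only when training),
    # keeping just the parameters whose shapes match the current model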
    if config.net.weights is not None:
        assert config.net.backbone == "sparseconv", "weight loading is only supported for the sparseconv backbone"
        logging.info('===> Loading weights: ' + config.net.weights)
        state = torch.load(config.net.weights, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        model = net
        if config.net.is_train:
            model = net.backbone_net.net
        matched_weights = load_state_with_same_shape(model, state['state_dict'])
        model_dict = model.state_dict()
        model_dict.update(matched_weights)
        model.load_state_dict(model_dict)

    net.to(device)

    if config.net.is_train:
        train(net, train_dataloader, test_dataloader, dataset_config, config)
    else:
        test(net, test_dataloader, dataset_config, config)