in downstream/semseg/ddp_main.py [0:0]
def main(config, init_distributed=False):
    if not torch.cuda.is_available():
        raise Exception('No GPUs FOUND.')

    # setup initial seed
    torch.cuda.set_device(config.distributed.device_id)
    torch.manual_seed(config.misc.seed)
    torch.cuda.manual_seed(config.misc.seed)

    device = config.distributed.device_id
    distributed = config.distributed.distributed_world_size > 1

    if init_distributed:
        config.distributed.distributed_rank = distributed_utils.distributed_init(config.distributed)

    setup_logging(config)
    logging.info('===> Configurations')
    logging.info(config.pretty())
    DatasetClass = load_dataset(config.data.dataset)
    if config.test.test_original_pointcloud:
        if not DatasetClass.IS_FULL_POINTCLOUD_EVAL:
            raise ValueError('This dataset does not support full pointcloud evaluation.')

    if config.test.evaluate_original_pointcloud:
        if not config.data.return_transformation:
            raise ValueError('Pointcloud evaluation requires config.return_transformation=true.')

    if config.data.return_transformation ^ config.test.evaluate_original_pointcloud:
        raise ValueError('Rotation evaluation requires config.evaluate_original_pointcloud=true and '
                         'config.return_transformation=true.')
    logging.info('===> Initializing dataloader')
    if config.train.is_train:
        train_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            phase=config.train.train_phase,
            num_workers=config.data.num_workers,
            augment_data=True,
            shuffle=True,
            repeat=True,
            batch_size=config.data.batch_size,
            limit_numpoints=config.data.train_limit_numpoints)

        val_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            num_workers=config.data.num_val_workers,
            phase=config.train.val_phase,
            augment_data=False,
            shuffle=True,
            repeat=False,
            batch_size=config.data.val_batch_size,
            limit_numpoints=False)

        if train_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = train_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = train_data_loader.dataset.NUM_LABELS
    else:
        test_data_loader = initialize_data_loader(
            DatasetClass,
            config,
            num_workers=config.data.num_workers,
            phase=config.data.test_phase,
            augment_data=False,
            shuffle=False,
            repeat=False,
            batch_size=config.data.test_batch_size,
            limit_numpoints=False)

        if test_data_loader.dataset.NUM_IN_CHANNEL is not None:
            num_in_channel = test_data_loader.dataset.NUM_IN_CHANNEL
        else:
            num_in_channel = 3  # RGB color

        num_labels = test_data_loader.dataset.NUM_LABELS
    logging.info('===> Building model')
    NetClass = load_model(config.net.model)
    if config.net.wrapper_type is None:
        model = NetClass(num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            NetClass.__name__, count_parameters(model)))
    else:
        wrapper = load_wrapper(config.net.wrapper_type)
        model = wrapper(NetClass, num_in_channel, num_labels, config)
        logging.info('===> Number of trainable parameters: {}: {}'.format(
            wrapper.__name__ + NetClass.__name__, count_parameters(model)))
    logging.info(model)
    if config.net.weights == 'modelzoo':  # Load modelzoo weights if possible.
        logging.info('===> Loading modelzoo weights')
        model.preload_modelzoo()
    # Load weights if specified by the parameter.
    elif config.net.weights.lower() != 'none':
        logging.info('===> Loading weights: ' + config.net.weights)
        state = torch.load(config.net.weights,
                           map_location=lambda s, l: default_restore_location(s, 'cpu'))

        # Checkpoints may store the weights under either of these keys.
        if 'state_dict' in state:
            state_key_name = 'state_dict'
        elif 'model_state' in state:
            state_key_name = 'model_state'
        else:
            raise NotImplementedError('Unrecognized checkpoint format: ' + str(state.keys()))

        if config.net.weights_for_inner_model:
            model.model.load_state_dict(state[state_key_name])
        else:
            if config.train.lenient_weight_loading:
                # Keep only checkpoint tensors whose shapes match the current model
                # (see the load_state_with_same_shape sketch after this function).
                matched_weights = load_state_with_same_shape(model, state[state_key_name])
                model_dict = model.state_dict()
                model_dict.update(matched_weights)
                model.load_state_dict(model_dict)
            else:
                model.load_state_dict(state[state_key_name])
    model = model.cuda()
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            module=model, device_ids=[device], output_device=device,
            broadcast_buffers=False, bucket_cap_mb=config.distributed.bucket_cap_mb)

    if config.train.is_train:
        train(model, train_data_loader, val_data_loader, config)
    else:
        test(model, test_data_loader, config)
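
# --- Sketch: lenient weight loading ---------------------------------------
# load_state_with_same_shape() used above is defined elsewhere in this repo;
# the version below is only a plausible minimal sketch of what the lenient
# path relies on: keep checkpoint tensors whose names and shapes match the
# current model and drop the rest (e.g. a classifier head sized for a
# different number of labels). Treat the exact behavior as an assumption.
def load_state_with_same_shape_sketch(model, checkpoint_state):
    model_state = model.state_dict()
    return {
        name: tensor
        for name, tensor in checkpoint_state.items()
        if name in model_state and tensor.size() == model_state[name].size()
    }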
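
# --- Sketch: per-GPU launch of main() --------------------------------------
# A minimal sketch (assumptions, not this repo's actual entry point) of how a
# main() like the one above is typically driven: one process per GPU via
# torch.multiprocessing, each worker pinning its device_id and calling main()
# with init_distributed=True so distributed_utils.distributed_init() can set
# up the process group before the DistributedDataParallel wrap.
import torch.multiprocessing as mp


def _worker_sketch(device_id, config):
    # Hypothetical worker: give each spawned process its own device and rank.
    config.distributed.device_id = device_id
    config.distributed.distributed_rank = device_id
    main(config, init_distributed=True)


def launch_sketch(config):
    world_size = config.distributed.distributed_world_size
    if world_size > 1:
        # mp.spawn passes the process index as the first argument to _worker_sketch.
        mp.spawn(_worker_sketch, args=(config,), nprocs=world_size, join=True)
    else:
        # Single-GPU / debug path: run in-process without a process group.
        main(config, init_distributed=False)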