downstream/votenet/lib/ddp_trainer.py
def __init__(self, config):
    self.is_master = is_master_proc(get_world_size()) if get_world_size() > 1 else True
    self.cur_device = torch.cuda.current_device()
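
    # Only rank 0 counts as the master process; is_master gates the
    # TensorBoard writer below so exactly one process writes log files.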

    # Load the configurations
    self.setup_logging()
    if os.path.exists('config.yaml'):
        logging.info('===> Loading existing config file')
        config = OmegaConf.load('config.yaml')
        logging.info('===> Loaded existing config file')
    logging.info('===> Configurations')
    logging.info(config.pretty())
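    # Note: a config.yaml in the working directory (presumably dumped by an
    # earlier run) overrides the config passed in, so resumed jobs keep their
    # original settings.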

    # Create Dataset and Dataloader
    if config.data.dataset == 'sunrgbd':
        from datasets.sunrgbd.sunrgbd_detection_dataset import SunrgbdDetectionVotesDataset, MAX_NUM_OBJ
        from datasets.sunrgbd.model_util_sunrgbd import SunrgbdDatasetConfig
        dataset_config = SunrgbdDatasetConfig()
        train_dataset = SunrgbdDetectionVotesDataset('train',
                                                     num_points=config.data.num_points,
                                                     augment=True,
                                                     use_color=config.data.use_color,
                                                     use_height=(not config.data.no_height),
                                                     use_v1=(not config.data.use_sunrgbd_v2))
        test_dataset = SunrgbdDetectionVotesDataset(config.test.phase,
                                                    num_points=config.data.num_points,
                                                    augment=False,
                                                    use_color=config.data.use_color,
                                                    use_height=(not config.data.no_height),
                                                    use_v1=(not config.data.use_sunrgbd_v2))
    elif config.data.dataset == 'scannet':
        from datasets.scannet.scannet_detection_dataset import ScannetDetectionDataset, MAX_NUM_OBJ
        from datasets.scannet.model_util_scannet import ScannetDatasetConfig
        dataset_config = ScannetDatasetConfig()
        train_dataset = ScannetDetectionDataset('train',
                                                num_points=config.data.num_points,
                                                augment=True,
                                                use_color=config.data.use_color,
                                                use_height=(not config.data.no_height),
                                                by_scenes=config.data.by_scenes,
                                                by_points=config.data.by_points)
        test_dataset = ScannetDetectionDataset(config.test.phase,
                                               num_points=config.data.num_points,
                                               augment=False,
                                               use_color=config.data.use_color,
                                               use_height=(not config.data.no_height))
    else:
        logging.error('Unknown dataset %s. Exiting...' % (config.data.dataset))
        exit(-1)

    COLLATE_FN = None
    if config.data.voxelization:
        from models.backbone.sparseconv.voxelized_dataset import VoxelizationDataset, collate_fn
        train_dataset = VoxelizationDataset(train_dataset, config.data.voxel_size)
        test_dataset = VoxelizationDataset(test_dataset, config.data.voxel_size)
        COLLATE_FN = collate_fn
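    # VoxelizationDataset is assumed to quantize each point-cloud sample into
    # voxel coordinates for the sparseconv backbone; the matching collate_fn
    # batches the variable-length voxel tensors that default collate cannot stack.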

    logging.info('training: {}, testing: {}'.format(len(train_dataset), len(test_dataset)))
    self.sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if get_world_size() > 1 else None
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.data.batch_size // config.misc.num_gpus,
        shuffle=(self.sampler is None),
        sampler=self.sampler,
        num_workers=config.data.num_workers,
        collate_fn=COLLATE_FN)
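    # With e.g. config.data.batch_size=32 and config.misc.num_gpus=4, each of
    # the 4 DDP processes loads 32 // 4 = 8 samples per step, so the effective
    # global batch stays 32. DataLoader forbids combining shuffle=True with a
    # sampler, hence shuffle=(self.sampler is None).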
    test_dataloader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=1,
        collate_fn=COLLATE_FN)
    logging.info('train dataloader: {}, test dataloader: {}'.format(len(train_dataloader), len(test_dataloader)))

    # Init the model and optimizer
    MODEL = importlib.import_module('models.' + config.net.model)  # import network module
    num_input_channel = int(config.data.use_color) * 3 + int(not config.data.no_height) * 1
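    # Example: use_color=True and no_height=False give 3 (RGB) + 1 (height
    # above floor) = 4 extra feature channels on top of the raw XYZ coordinates.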
    if config.net.model == 'boxnet':
        Detector = MODEL.BoxNet
    else:
        Detector = MODEL.VoteNet

    net = Detector(num_class=dataset_config.num_class,
                   num_heading_bin=dataset_config.num_heading_bin,
                   num_size_cluster=dataset_config.num_size_cluster,
                   mean_size_arr=dataset_config.mean_size_arr,
                   num_proposal=config.net.num_target,
                   input_feature_dim=num_input_channel,
                   vote_factor=config.net.vote_factor,
                   sampling=config.net.cluster_sampling,
                   backbone=config.net.backbone)
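    # BoxNet is VoteNet's no-voting baseline; both classes are assumed to share
    # this constructor signature, so they can be swapped via config.net.model.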

    if config.net.weights != '':
        # assert config.net.backbone == "sparseconv", "only support sparseconv"
        logging.info('===> Loading weights: ' + config.net.weights)
        state = torch.load(config.net.weights, map_location=lambda s, l: default_restore_location(s, 'cpu'))
        model = net
        if config.net.is_train:
            model = net.backbone_net
            if config.net.backbone == "sparseconv":
                model = net.backbone_net.net
        matched_weights = DetectionTrainer.load_state_with_same_shape(model, state['state_dict'])
        model_dict = model.state_dict()
        model_dict.update(matched_weights)
        model.load_state_dict(model_dict)
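        # load_state_with_same_shape presumably keeps only checkpoint entries
        # whose name and tensor shape match the target module, so pretrained
        # backbone weights load partially without a strict-key mismatch error.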

    net.to(self.cur_device)
    if get_world_size() > 1:
        net = torch.nn.parallel.DistributedDataParallel(
            module=net, device_ids=[self.cur_device], output_device=self.cur_device, broadcast_buffers=False)
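    # broadcast_buffers=False stops DDP from re-broadcasting buffers (e.g. BN
    # running statistics) from rank 0 on every forward; each process keeps its own.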

    # Set up the Adam optimizer
    self.optimizer = optim.Adam(net.parameters(), lr=config.optimizer.learning_rate, weight_decay=config.optimizer.weight_decay)

    # TensorBoard writer (master process only)
    if self.is_master:
        self.writer = SummaryWriter(log_dir='tensorboard')

    self.config = config
    self.dataset_config = dataset_config
    self.net = net
    self.train_dataloader = train_dataloader
    self.test_dataloader = test_dataloader
    self.best_mAP = -1

    # Used for AP calculation during training-time evaluation
    self.CONFIG_DICT = {'remove_empty_box': False, 'use_3d_nms': True,
                        'nms_iou': 0.25, 'use_old_type_nms': False, 'cls_nms': True,
                        'per_class_proposal': True, 'conf_thresh': 0.05, 'dataset_config': dataset_config}

    # Used for AP calculation at test time (honours the config.test.* flags)
    self.CONFIG_DICT_TEST = {'remove_empty_box': (not config.test.faster_eval),
                             'use_3d_nms': config.test.use_3d_nms,
                             'nms_iou': config.test.nms_iou,
                             'use_old_type_nms': config.test.use_old_type_nms,
                             'cls_nms': config.test.use_cls_nms,
                             'per_class_proposal': config.test.per_class_proposal,
                             'conf_thresh': config.test.conf_thresh,
                             'dataset_config': dataset_config}
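    # The low conf_thresh (0.05) deliberately keeps low-confidence proposals:
    # with per_class_proposal=True, AP is then computed over the full
    # precision-recall curve rather than at a single operating point.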

    # Load checkpoint if one exists
    self.start_epoch = 0
    CHECKPOINT_PATH = 'checkpoint.tar'
    if os.path.isfile(CHECKPOINT_PATH):
        checkpoint = torch.load(CHECKPOINT_PATH)
        # DDP wraps the model, so unwrap .module before loading raw state_dict keys
        if get_world_size() > 1:
            _model = self.net.module
        else:
            _model = self.net
        _model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        self.start_epoch = checkpoint['epoch']
        self.best_mAP = checkpoint['best_mAP']
        logging.info("-> loaded checkpoint %s (epoch: %d)" % (CHECKPOINT_PATH, self.start_epoch))

    # Decay batchnorm momentum from 0.5 to 0.001 over training (in TensorFlow
    # terms it rises from 0.5 to 0.999)
    # note: PyTorch's BN momentum (default 0.1) = 1 - TensorFlow's BN momentum
    BN_MOMENTUM_INIT = 0.5
    BN_MOMENTUM_MAX = 0.001
    BN_DECAY_STEP = config.optimizer.bn_decay_step
    BN_DECAY_RATE = config.optimizer.bn_decay_rate
    bn_lbmd = lambda it: max(BN_MOMENTUM_INIT * BN_DECAY_RATE ** (int(it / BN_DECAY_STEP)), BN_MOMENTUM_MAX)
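    # Worked example: with bn_decay_step=20 and bn_decay_rate=0.5 the momentum
    # is 0.5 for epochs 0-19, 0.25 for 20-39, 0.125 for 40-59, ... until it
    # clamps at 0.001 (BN_MOMENTUM_MAX is actually the floor, despite its name).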
    self.bnm_scheduler = BNMomentumScheduler(net, bn_lambda=bn_lbmd, last_epoch=self.start_epoch - 1)