in scripts/train_instance_seg.py [0:0]
def main(args):
    # Initialize multi-processing
    distributed.init_process_group(backend='nccl', init_method='env://')
    device_id, device = args.local_rank, torch.device(args.local_rank)
    rank, world_size = distributed.get_rank(), distributed.get_world_size()
    torch.cuda.set_device(device_id)
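    # Note: the env:// rendezvous reads MASTER_ADDR, MASTER_PORT, RANK and WORLD_SIZE from the
    # environment, so this function assumes it runs under a distributed launcher that sets them
    # and passes the per-process local rank. Pinning the CUDA device keeps each process on its own GPU.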
    # Initialize logging
    if rank == 0:
        logging.init(args.log_dir, "training" if not args.eval else "eval")
        summary = tensorboard.SummaryWriter(args.log_dir)
    else:
        summary = None
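    # Only rank 0 writes log files and TensorBoard summaries; the other ranks keep summary=None
    # so per-process output is not duplicated.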
    # Load configuration
    config = make_config(args)
    # Create dataloaders
    train_dataloader, val_dataloader = make_dataloader(args, config, rank, world_size)
    # Create model
    model = make_model(config, train_dataloader.dataset.num_thing, train_dataloader.dataset.num_stuff)
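    # The heads are sized from the dataset metadata: num_thing counts the instance ("thing")
    # classes and num_stuff the remaining semantic ("stuff") classes.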
    if args.resume:
        assert not args.pre_train, "resume and pre_train are mutually exclusive"
        log_debug("Loading snapshot from %s", args.resume)
        snapshot = resume_from_snapshot(model, args.resume, ["body", "rpn_head", "roi_head"])
    elif args.pre_train:
        assert not args.resume, "resume and pre_train are mutually exclusive"
        log_debug("Loading pre-trained model from %s", args.pre_train)
        pre_train_from_snapshots(model, args.pre_train, ["body", "rpn_head", "roi_head"])
    else:
        assert not args.eval, "--resume is needed in eval mode"
        snapshot = None
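    # Three start-up modes: --resume restores a full training snapshot (weights plus the optimizer
    # and meter state loaded below), --pre_train only initializes the listed modules from
    # pre-trained weights, and otherwise training starts from scratch; evaluation requires --resume.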
    # Init GPU stuff
    torch.backends.cudnn.benchmark = config["general"].getboolean("cudnn_benchmark")
    model = DistributedDataParallel(model.cuda(device), device_ids=[device_id], output_device=device_id,
                                    find_unused_parameters=True)
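    # The model is moved to this process's GPU before being wrapped in DistributedDataParallel;
    # find_unused_parameters=True lets gradient reduction succeed even when some head parameters
    # receive no gradient in a given iteration (for example, batches that yield no ROI targets).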
    # Create optimizer
    optimizer, scheduler, batch_update, total_epochs = make_optimizer(config, model, len(train_dataloader))
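    # batch_update indicates the schedule granularity: when True the scheduler is presumably
    # stepped inside train() (which receives it below), otherwise it is stepped once per epoch
    # in the main loop.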
    if args.resume:
        optimizer.load_state_dict(snapshot["state_dict"]["optimizer"])
    # Training loop
    momentum = 1. - 1. / len(train_dataloader)
    meters = {
        "loss": AverageMeter((), momentum),
        "obj_loss": AverageMeter((), momentum),
        "bbx_loss": AverageMeter((), momentum),
        "roi_cls_loss": AverageMeter((), momentum),
        "roi_bbx_loss": AverageMeter((), momentum),
        "roi_msk_loss": AverageMeter((), momentum)
    }
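    # Assumption: AverageMeter tracks an exponential moving average with the given momentum,
    # roughly value = momentum * value + (1 - momentum) * sample, so a momentum of
    # 1 - 1 / len(train_dataloader) smooths each loss over approximately one epoch of batches.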
    if args.resume:
        starting_epoch = snapshot["training_meta"]["epoch"] + 1
        best_score = snapshot["training_meta"]["best_score"]
        global_step = snapshot["training_meta"]["global_step"]
        for name, meter in meters.items():
            meter.load_state_dict(snapshot["state_dict"][name + "_meter"])
        del snapshot
    else:
        starting_epoch = 0
        best_score = 0
        global_step = 0
    # Optional: evaluation only
    if args.eval:
        log_info("Validating epoch %d", starting_epoch - 1)
        validate(model, val_dataloader, config["optimizer"].getstruct("loss_weights"),
                 device=device, summary=summary, global_step=global_step,
                 epoch=starting_epoch - 1, num_epochs=total_epochs,
                 log_interval=config["general"].getint("log_interval"),
                 coco_gt=config["dataloader"]["coco_gt"], log_dir=args.log_dir)
        exit(0)
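        # In eval mode the restored snapshot is validated once against the COCO ground truth
        # and the process exits without entering the training loop below.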
    for epoch in range(starting_epoch, total_epochs):
        log_info("Starting epoch %d", epoch + 1)
        if not batch_update:
            scheduler.step(epoch)
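        # Epoch-based schedulers are advanced here; passing the epoch index to scheduler.step()
        # follows the older PyTorch calling convention, which is deprecated in recent releases.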
        # Run training epoch
        global_step = train(model, optimizer, scheduler, train_dataloader, meters,
                            batch_update=batch_update, epoch=epoch, summary=summary, device=device,
                            log_interval=config["general"].getint("log_interval"), num_epochs=total_epochs,
                            global_step=global_step, loss_weights=config["optimizer"].getstruct("loss_weights"))
        # Save snapshot (only on rank 0)
        if rank == 0:
            snapshot_file = path.join(args.log_dir, "model_last.pth.tar")
            log_debug("Saving snapshot to %s", snapshot_file)
            meters_out_dict = {k + "_meter": v.state_dict() for k, v in meters.items()}
            save_snapshot(snapshot_file, config, epoch, 0, best_score, global_step,
                          body=model.module.body.state_dict(),
                          rpn_head=model.module.rpn_head.state_dict(),
                          roi_head=model.module.roi_head.state_dict(),
                          optimizer=optimizer.state_dict(),
                          **meters_out_dict)
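            # Only rank 0 serializes model_last.pth.tar; module weights are read from
            # model.module because the network is wrapped in DistributedDataParallel.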
        if (epoch + 1) % config["general"].getint("val_interval") == 0:
            log_info("Validating epoch %d", epoch + 1)
            score = validate(model, val_dataloader, config["optimizer"].getstruct("loss_weights"),
                             device=device, summary=summary, global_step=global_step,
                             epoch=epoch, num_epochs=total_epochs,
                             log_interval=config["general"].getint("log_interval"),
                             coco_gt=config["dataloader"]["coco_gt"], log_dir=args.log_dir)
            # Update the score on the last saved snapshot
            if rank == 0:
                snapshot = torch.load(snapshot_file, map_location="cpu")
                snapshot["training_meta"]["last_score"] = score
                torch.save(snapshot, snapshot_file)
                del snapshot
            if score > best_score:
                best_score = score
                if rank == 0:
                    shutil.copy(snapshot_file, path.join(args.log_dir, "model_best.pth.tar"))
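# Assumed usage (argument parsing and the __main__ entry point live elsewhere in this script):
# the env:// initialization above expects a distributed launcher such as torch.distributed.launch,
# which exports MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE and passes the per-process --local_rank, e.g.:
#   python -m torch.distributed.launch --nproc_per_node=<num_gpus> scripts/train_instance_seg.py <script arguments>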