in references/detection/train.py [0:0]
def main(args):
    if args.prototype and prototype is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
    if not args.prototype and args.weights:
        raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    device = torch.device(args.device)

    # Data loading code
    print("Loading data")

    dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path)
    dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path)

    print("Creating data loaders")
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)
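
    # Optionally group images of similar aspect ratio into the same batch, which
    # reduces the padding needed when images are batched together.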
    if args.aspect_ratio_group_factor >= 0:
        group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
        train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
    else:
        train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)
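
    # Detection samples vary in size, so utils.collate_fn keeps each batch as a
    # list of (image, target) pairs instead of stacking them into a single tensor.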
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
    )
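
    # Models are looked up by name in torchvision.models.detection; rpn_score_thresh
    # only applies to the R-CNN family, so it is added to kwargs conditionally.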
print("Creating model")
kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
if "rcnn" in args.model:
if args.rpn_score_thresh is not None:
kwargs["rpn_score_thresh"] = args.rpn_score_thresh
if not args.prototype:
model = torchvision.models.detection.__dict__[args.model](
pretrained=args.pretrained, num_classes=num_classes, **kwargs
)
else:
model = prototype.models.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
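
    # Keep a handle to the unwrapped model so checkpoints store the plain weights
    # rather than the DistributedDataParallel wrapper.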
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
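
    # Only parameters that require gradients are optimized (e.g. frozen backbone
    # layers are excluded).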
    params = [p for p in model.parameters() if p.requires_grad]
    optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)

    scaler = torch.cuda.amp.GradScaler() if args.amp else None
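
    # The GradScaler exists only when mixed precision is requested with --amp;
    # otherwise train_one_epoch receives None and runs in full precision.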
    args.lr_scheduler = args.lr_scheduler.lower()
    if args.lr_scheduler == "multisteplr":
        lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
    elif args.lr_scheduler == "cosineannealinglr":
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    else:
        raise RuntimeError(
            f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
        )
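
    # Resuming restores model, optimizer and scheduler state (plus the GradScaler
    # when AMP was used) and continues from the epoch after the saved one.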
    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])
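
    # --test-only skips training entirely: run one evaluation pass and exit.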
    if args.test_only:
        evaluate(model, data_loader_test, device=device)
        return
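
    # Main loop: re-seed the distributed sampler each epoch, train, step the LR
    # schedule, checkpoint, then evaluate on the validation set.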
print("Start training")
start_time = time.time()
for epoch in range(args.start_epoch, args.epochs):
if args.distributed:
train_sampler.set_epoch(epoch)
train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler)
lr_scheduler.step()
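        # Checkpoints are written only on the main process (utils.save_on_master):
        # one snapshot per epoch plus a rolling "checkpoint.pth".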
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "args": args,
                "epoch": epoch,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

        # evaluate after every epoch
        evaluate(model, data_loader_test, device=device)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")