in projects_oss/detr/main.py [0:0]
def main(args):
    # utils.init_distributed_mode(args)
    if args.frozen_weights is not None:
        assert args.masks, "Frozen training is meant for segmentation only"
    print(args)
    device = torch.device(args.device)
    # fix the seed for reproducibility
    seed = args.seed + utils.get_rank()
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
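    # build the DETR model, its set-prediction criterion (Hungarian matching + losses),
    # and the post-processors that convert raw outputs into the COCO evaluation format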
    model, criterion, postprocessors = build_model(args)
    model.to(device)
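    # keep a handle to the unwrapped model; once the DistributedDataParallel wrapper is
    # applied, the plain module is still needed for parameter groups and checkpointing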
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print("number of params:", n_parameters)
    param_dicts = [
        {
            "params": [
                p
                for n, p in model_without_ddp.named_parameters()
                if "backbone" not in n and p.requires_grad
            ]
        },
        {
            "params": [
                p
                for n, p in model_without_ddp.named_parameters()
                if "backbone" in n and p.requires_grad
            ],
            "lr": args.lr_backbone,
        },
    ]
    optimizer = torch.optim.AdamW(
        param_dicts, lr=args.lr, weight_decay=args.weight_decay
    )
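    # drop the learning rate by StepLR's default factor of 0.1 every args.lr_drop epochs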
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, args.lr_drop)
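    # build the train/val datasets selected by args.dataset_file (e.g. coco, coco_panoptic)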
    dataset_train = build_dataset(image_set="train", args=args)
    dataset_val = build_dataset(image_set="val", args=args)
    if args.distributed:
        sampler_train = DistributedSampler(dataset_train)
        sampler_val = DistributedSampler(dataset_val, shuffle=False)
    else:
        sampler_train = torch.utils.data.RandomSampler(dataset_train)
        sampler_val = torch.utils.data.SequentialSampler(dataset_val)
    batch_sampler_train = torch.utils.data.BatchSampler(
        sampler_train, args.batch_size, drop_last=True
    )
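    # utils.collate_fn handles batching of variable-sized images and their targets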
    data_loader_train = DataLoader(
        dataset_train,
        batch_sampler=batch_sampler_train,
        collate_fn=utils.collate_fn,
        num_workers=args.num_workers,
    )
    data_loader_val = DataLoader(
        dataset_val,
        args.batch_size,
        sampler=sampler_val,
        drop_last=False,
        collate_fn=utils.collate_fn,
        num_workers=args.num_workers,
    )
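    # base_ds is the underlying COCO API object the evaluator needs to compute AP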
    if args.dataset_file == "coco_panoptic":
        # We also evaluate AP during panoptic training, on the original COCO dataset
        coco_val = datasets.coco.build("val", args)
        base_ds = get_coco_api_from_dataset(coco_val)
    else:
        base_ds = get_coco_api_from_dataset(dataset_val)
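    # frozen training: load the pretrained detection weights into the inner DETR model
    # so that only the segmentation head is trained on top of them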
    if args.frozen_weights is not None:
        checkpoint = torch.load(args.frozen_weights, map_location="cpu")
        model_without_ddp.detr.load_state_dict(checkpoint["model"])
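    # resume from a checkpoint given as a URL or a local path; the optimizer, lr scheduler
    # and start epoch are also restored when present and not running in eval-only mode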
    if args.resume:
        if args.resume.startswith("https"):
            checkpoint = torch.hub.load_state_dict_from_url(
                args.resume, map_location="cpu", check_hash=True
            )
        else:
            checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        if (
            not args.eval
            and "optimizer" in checkpoint
            and "lr_scheduler" in checkpoint
            and "epoch" in checkpoint
        ):
            optimizer.load_state_dict(checkpoint["optimizer"])
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
            args.start_epoch = checkpoint["epoch"] + 1
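    # evaluation-only mode: run a single validation pass, optionally dump the COCO bbox
    # evaluation results to eval.pth, and exit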
    if args.eval:
        test_stats, coco_evaluator = evaluate(
            model,
            criterion,
            postprocessors,
            data_loader_val,
            base_ds,
            device,
            args.output_dir,
        )
        if args.output_dir:
            with PathManager.open(os.path.join(args.output_dir, "eval.pth"), "wb") as f:
                utils.save_on_master(coco_evaluator.coco_eval["bbox"].eval, f)
        return
print("Start training")
start_time = time.time()
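    # main training loop; set_epoch() re-seeds the DistributedSampler so every epoch
    # uses a different shuffle across processes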
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            sampler_train.set_epoch(epoch)
        train_stats = train_one_epoch(
            model,
            criterion,
            data_loader_train,
            optimizer,
            device,
            epoch,
            args.clip_max_norm,
        )
        lr_scheduler.step()
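        # periodically write checkpoints to args.output_dir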
        if args.output_dir:
            checkpoint_paths = []  # [os.path.join(args.output_dir, "checkpoint.pth")]
            # extra checkpoint before LR drop and every 10 epochs
            if (epoch + 1) % args.lr_drop == 0 or (epoch + 1) % 10 == 0:
                checkpoint_paths.append(
                    os.path.join(args.output_dir, f"checkpoint{epoch:04}.pth")
                )
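            # only the global main process (gpu 0 on machine rank 0) writes checkpoint files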
            for checkpoint_path in checkpoint_paths:
                with PathManager.open(checkpoint_path, "wb") as f:
                    if args.gpu == 0 and args.machine_rank == 0:
                        utils.save_on_master(
                            {
                                "model": model_without_ddp.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "lr_scheduler": lr_scheduler.state_dict(),
                                "epoch": epoch,
                                "args": args,
                            },
                            f,
                        )
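        # run COCO evaluation on the validation set after every epoch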
        test_stats, coco_evaluator = evaluate(
            model,
            criterion,
            postprocessors,
            data_loader_val,
            base_ds,
            device,
            args.output_dir,
        )
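        # merge train and test metrics into a single per-epoch record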
        log_stats = {
            **{f"train_{k}": v for k, v in train_stats.items()},
            **{f"test_{k}": v for k, v in test_stats.items()},
            "epoch": epoch,
            "n_parameters": n_parameters,
        }
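        # only the main process writes log.txt (rewritten each epoch) and the evaluator dumps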
        if args.output_dir and utils.is_main_process():
            with PathManager.open(os.path.join(args.output_dir, "log.txt"), "w") as f:
                f.write(json.dumps(log_stats) + "\n")
            # for evaluation logs
            if coco_evaluator is not None:
                PathManager.mkdirs(os.path.join(args.output_dir, "eval"))
                if "bbox" in coco_evaluator.coco_eval:
                    filenames = ["latest.pth"]
                    if epoch % 50 == 0:
                        filenames.append(f"{epoch:03}.pth")
                    for name in filenames:
                        with PathManager.open(
                            os.path.join(args.output_dir, "eval", name), "wb"
                        ) as f:
                            torch.save(coco_evaluator.coco_eval["bbox"].eval, f)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))