in Segmentation/train.py [0:0]
def main(args):
    args.log_dir = save_path_formatter(args)

    if args.deconv:
        args.deconv = partial(Deconv, bias=args.bias, eps=args.eps, n_iter=args.deconv_iter,
                              block=args.block, sampling_stride=args.stride, sync=args.sync,
                              norm_type=args.norm_type)
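    # Editor's note (illustrative, not from the original file): functools.partial pre-binds the
    # Deconv hyperparameters so that args.deconv can later be used as a drop-in layer factory.
    # A minimal, self-contained sketch of the same pattern (nn.BatchNorm2d used only as an example):
    #
    #     from functools import partial
    #     import torch.nn as nn
    #     norm_layer = partial(nn.BatchNorm2d, eps=1e-3, momentum=0.05)
    #     layer = norm_layer(64)   # same as nn.BatchNorm2d(64, eps=1e-3, momentum=0.05)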
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    args.ngpus_per_node = torch.cuda.device_count()
    print(args)

    if args.tensorboard and utils.is_main_process():
        from torch.utils.tensorboard import SummaryWriter
        args.writer = SummaryWriter(args.log_dir, flush_secs=30)

    device = torch.device(args.device)
    transform = get_transform(mode='train', base_size=args.base_size)
    dataset, num_classes = get_dataset(args.dataset, "train", transform=transform)
    transform = get_transform(mode='test', base_size=args.base_size)
    dataset_test, _ = get_dataset(args.dataset, "val", transform=transform)

    args.colormap = create_mapillary_vistas_label_colormap()
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn, drop_last=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size,
        sampler=test_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)
    # model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes, aux_loss=args.aux_loss, pretrained=args.pretrained)
    model = models.segmentation.__dict__[args.model](num_classes=num_classes, aux_loss=args.aux_loss,
                                                     pretrained=args.pretrained, deconv=args.deconv,
                                                     pretrained_backbone=args.pretrained_backbone)
    model.to(device)
    if args.distributed:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.resume:
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        args.start_epoch = checkpoint['epoch']
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        del checkpoint
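    # Note: resuming restores only the model weights here; the optimizer and LR schedule are
    # re-created below, and the schedule is fast-forwarded to args.start_epoch.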
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    elif args.device == 'cuda':
        model = torch.nn.DataParallel(model).cuda()

    print(model)
    if args.test_only:
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        return
    if args.pretrained_backbone and args.deconv:
        params_to_optimize = [
            # args.lr * 0.1 for the backbone can potentially give better results
            {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad], "lr": args.lr},
            {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad], "lr": args.lr},
        ]
    else:
        params_to_optimize = [
            {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
            {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
        ]
    if args.aux_loss:
        params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
        params_to_optimize.append({"params": params, "lr": args.lr * 10})
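    # Editor's note (illustrative, not from the original file): a per-group "lr" entry overrides
    # the optimizer-wide default, while groups without one inherit it. Minimal sketch with
    # placeholder modules `backbone` and `head`:
    #
    #     opt = torch.optim.SGD(
    #         [{"params": backbone.parameters()},            # uses lr=0.01 from below
    #          {"params": head.parameters(), "lr": 0.1}],    # group-specific override
    #         lr=0.01, momentum=0.9)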
    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(
            params_to_optimize,
            lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(params_to_optimize, lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise ValueError('Unsupported optimizer: {}'.format(args.optimizer))
    total_steps = len(data_loader) * args.epochs

    if args.lr_scheduler == 'cosine':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps, eta_min=0, last_epoch=-1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (1 - x / total_steps) ** 0.9)
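    # Editor's note (illustrative, not from the original file): both schedules are defined over
    # per-batch steps (total_steps = batches_per_epoch * epochs), which assumes train_one_epoch
    # calls lr_scheduler.step() once per iteration. For the LambdaLR branch the multiplier at
    # step x is (1 - x / total_steps) ** 0.9; a quick numeric check:
    #
    #     base_lr, total = 0.01, 10000
    #     lr_mid = base_lr * (1 - 5000 / total) ** 0.9   # ~0.00536 at the halfway point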
    if args.resume:
        # fast-forward the per-iteration LR schedule (and the global step counter) to the
        # resumed position; a separate variable keeps the scheduler lambda above decaying
        # over the full len(data_loader) * args.epochs horizon
        resume_steps = len(data_loader) * args.start_epoch
        global n_iter
        for _ in range(resume_steps):
            n_iter = n_iter + 1
            lr_scheduler.step()
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)

        if epoch == 0 or (epoch + 1) % args.eval_freq == 0:
            confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
            utils.save_on_master(
                {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    # 'args': args
                },
                # os.path.join(args.log_dir, 'model_{}.pth'.format(epoch))
                os.path.join(args.log_dir, 'model.pth'))

            print(confmat)
            acc_global, acc, iu = confmat.compute()
            acc_global = acc_global.item() * 100
            iu = iu.mean().item() * 100

            if args.tensorboard and utils.is_main_process():
                args.writer.add_scalar('Acc/Test', acc_global, epoch + 1)
                args.writer.add_scalar('IOU/Test', iu, epoch + 1)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))

    if args.tensorboard and utils.is_main_process():
        args.writer.close()
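
# ---------------------------------------------------------------------------------------------
# Illustrative usage sketch (editor's addition, not part of the original file): main() expects an
# argparse-style namespace. `parse_args` is a hypothetical helper named only for this example;
# the flags it would define correspond to the attributes accessed above (dataset, model, lr, ...).
#
#     if __name__ == "__main__":
#         args = parse_args()
#         main(args)
# ---------------------------------------------------------------------------------------------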