def main()

in Segmentation/train.py [0:0]


def main(args):

    args.log_dir = save_path_formatter(args)
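    # If --deconv is enabled, replace the flag with a partially-applied Deconv constructor
    # so the model factory below can build its deconvolution layers with these settings.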
    if args.deconv:
        args.deconv = partial(Deconv, bias=args.bias, eps=args.eps, n_iter=args.deconv_iter,
                              block=args.block, sampling_stride=args.stride, sync=args.sync,
                              norm_type=args.norm_type)


    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    args.ngpus_per_node = torch.cuda.device_count()

    print(args)

    if args.tensorboard and utils.is_main_process():
        from torch.utils.tensorboard import SummaryWriter
        args.writer = SummaryWriter(args.log_dir, flush_secs=30)


    device = torch.device(args.device)
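    # Build the training and validation datasets with their respective transforms.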
    transform = get_transform(mode='train', base_size=args.base_size)
    dataset, num_classes = get_dataset(args.dataset, "train", transform=transform)

    transform = get_transform(mode='test', base_size=args.base_size)
    dataset_test, _ = get_dataset(args.dataset, "val", transform=transform)

    args.colormap = create_mapillary_vistas_label_colormap()

    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
        test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
    else:
        train_sampler = torch.utils.data.RandomSampler(dataset)
        test_sampler = torch.utils.data.SequentialSampler(dataset_test)

    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=args.batch_size,
        sampler=train_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn, drop_last=True)

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test, batch_size=args.batch_size,
        sampler=test_sampler, num_workers=args.workers,
        collate_fn=utils.collate_fn)

    #model = torchvision.models.segmentation.__dict__[args.model](num_classes=num_classes,aux_loss=args.aux_loss,pretrained=args.pretrained)
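    # The local models.segmentation factory mirrors the torchvision call above, but also
    # accepts the deconv and pretrained_backbone arguments.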
    model = models.segmentation.__dict__[args.model](num_classes=num_classes, aux_loss=args.aux_loss,
                                                     pretrained=args.pretrained, deconv=args.deconv,
                                                     pretrained_backbone=args.pretrained_backbone)

    model.to(device)


    if args.distributed:
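        # Convert BatchNorm layers to SyncBatchNorm so statistics are synchronized across processes.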
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.resume:
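        # Restore the model weights and the starting epoch from a checkpoint (loaded on CPU first).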
        checkpoint = torch.load(args.resume, map_location='cpu')
        model.load_state_dict(checkpoint['model'])
        args.start_epoch = checkpoint['epoch']
        print("=> loaded checkpoint '{}' (epoch {})"
                  .format(args.resume, checkpoint['epoch']))
        del checkpoint

    model_without_ddp = model
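    # Keep a reference to the unwrapped model for checkpointing and parameter groups,
    # then wrap the model for distributed or multi-GPU training.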
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module
    elif args.device == 'cuda':
        model = torch.nn.DataParallel(model).cuda()

    print(model)
    if args.test_only:
        confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)
        print(confmat)
        return

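    # Separate parameter groups for the backbone and the classifier head.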
    if args.pretrained_backbone and args.deconv:
        params_to_optimize = [
            {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad], "lr": args.lr},  # args.lr * 0.1 can potentially give better results
            {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad], "lr": args.lr},
        ]
    
    else:
        params_to_optimize = [
            {"params": [p for p in model_without_ddp.backbone.parameters() if p.requires_grad]},
            {"params": [p for p in model_without_ddp.classifier.parameters() if p.requires_grad]},
        ]
    if args.aux_loss:
        params = [p for p in model_without_ddp.aux_classifier.parameters() if p.requires_grad]
        params_to_optimize.append({"params": params, "lr": args.lr * 10})
    
    if args.optimizer == 'SGD':
        optimizer = torch.optim.SGD(
            params_to_optimize,
            lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
    elif args.optimizer == 'Adam':
        optimizer = torch.optim.Adam(params_to_optimize, lr=args.lr, weight_decay=args.weight_decay)
    else:
        raise ValueError('unsupported optimizer: {}'.format(args.optimizer))
        
    total_steps = len(data_loader) * args.epochs
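    # total_steps is the number of training iterations; judging by the resume fast-forward
    # below, the schedulers are stepped once per iteration, so they are defined over
    # iterations rather than epochs.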

    if args.lr_scheduler == 'cosine':
        lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, total_steps, eta_min=0, last_epoch=-1)
    else:
        lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lambda x: (1 - x / total_steps) ** 0.9)



    if args.resume:
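        # Fast-forward the global iteration counter and the scheduler to where the checkpoint left off.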
        total_steps = len(data_loader) * args.start_epoch
        global n_iter
        for i in range(total_steps):
            n_iter = n_iter + 1            
            lr_scheduler.step()

    start_time = time.time()
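    # Main training loop: criterion, train_one_epoch and evaluate are presumably module-level
    # helpers defined (or imported) elsewhere in Segmentation/train.py and are not shown in this excerpt.
    # Every eval_freq epochs (and after epoch 0), the model is evaluated and checkpointed.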
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, data_loader, lr_scheduler, device, epoch, args.print_freq)
        
        if epoch == 0 or (epoch+1) % args.eval_freq == 0:
            confmat = evaluate(model, data_loader_test, device=device, num_classes=num_classes)

            utils.save_on_master(
                {
                    'model': model_without_ddp.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'epoch': epoch,
                    #'args': args
                },
                #os.path.join(args.log_dir, 'model_{}.pth'.format(epoch)))
                os.path.join(args.log_dir, 'model.pth'))
            
            print(confmat)

            acc_global, acc, iu = confmat.compute()
            acc_global = acc_global.item() * 100
            iu = iu.mean().item() * 100

            if args.tensorboard and utils.is_main_process():
                args.writer.add_scalar('Acc/Test', acc_global, epoch+1)
                args.writer.add_scalar('IOU/Test', iu, epoch+1)

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print('Training time {}'.format(total_time_str))
    if args.tensorboard and utils.is_main_process():
        args.writer.close()