in segmentation/tool/train.py [0:0]
def main_worker(gpu, ngpus_per_node, argss):
    global args
    args = argss
    if args.distributed:
        if args.dist_url == "env://" and args.rank == -1:
            args.rank = int(os.environ["RANK"])
        if args.multiprocessing_distributed:
            args.rank = args.rank * ngpus_per_node + gpu
        dist.init_process_group(backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank)
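
    # Build the model for the selected architecture; backbone layers and the newly
    # added heads are collected separately so they can use different learning rates.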
    criterion = nn.CrossEntropyLoss(ignore_index=args.ignore_label)
    if args.arch == 'psp':
        from model.pspnet import PSPNet
        model = PSPNet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, criterion=criterion)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.ppm, model.cls, model.aux]
    elif args.arch == 'psa':
        from model.psanet import PSANet
        model = PSANet(layers=args.layers, classes=args.classes, zoom_factor=args.zoom_factor, psa_type=args.psa_type,
                       compact=args.compact, shrink_factor=args.shrink_factor, mask_h=args.mask_h, mask_w=args.mask_w,
                       normalization_factor=args.normalization_factor, psa_softmax=args.psa_softmax, criterion=criterion)
        modules_ori = [model.layer0, model.layer1, model.layer2, model.layer3, model.layer4]
        modules_new = [model.psa, model.cls, model.aux]
    elif args.arch == 'fcn':
        from model.fcn import FCNet
        model = FCNet(layers=args.layers, classes=args.classes, criterion=criterion)
        modules_ori = [model.model.backbone]
        modules_new = [model.model.classifier, model.model.aux_classifier]
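
    # Pre-trained backbone parameters train at base_lr; the newly added modules train at 10x base_lr.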
    params_list = []
    for module in modules_ori:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr))
    for module in modules_new:
        params_list.append(dict(params=module.parameters(), lr=args.base_lr * 10))
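
    # index_split marks where the backbone parameter groups end, so the learning-rate
    # schedule can scale the backbone and new-module groups differently during training.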
    if args.arch == 'fcn':
        args.index_split = 1
    else:
        args.index_split = 5
    optimizer = torch.optim.SGD(params_list, lr=args.base_lr, momentum=args.momentum, weight_decay=args.weight_decay)
    if args.sync_bn:
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
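
    # Logging and TensorBoard writing happen only on the main process.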
    if main_process():
        global logger, writer
        logger = get_logger()
        writer = SummaryWriter(args.save_path)
        logger.info(args)
        logger.info("=> creating model ...")
        logger.info("Classes: {}".format(args.classes))
        logger.info(model)
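
    # With DistributedDataParallel each process drives one GPU, so the per-process
    # batch sizes and worker count are divided by the number of GPUs on the node.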
    if args.distributed:
        torch.cuda.set_device(gpu)
        args.batch_size = int(args.batch_size / ngpus_per_node)
        args.batch_size_val = int(args.batch_size_val / ngpus_per_node)
        args.workers = int((args.workers + ngpus_per_node - 1) / ngpus_per_node)
        model = torch.nn.parallel.DistributedDataParallel(model.cuda(), device_ids=[gpu])
    else:
        model = torch.nn.DataParallel(model.cuda())
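
    # Optionally initialize from a pre-trained weight file (model weights only).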
    if args.weight:
        if os.path.isfile(args.weight):
            if main_process():
                logger.info("=> loading weight '{}'".format(args.weight))
            checkpoint = torch.load(args.weight)
            model.load_state_dict(checkpoint['state_dict'])
            if main_process():
                logger.info("=> loaded weight '{}'".format(args.weight))
        else:
            if main_process():
                logger.info("=> no weight found at '{}'".format(args.weight))
    if args.resume:
        if os.path.isfile(args.resume):
            if main_process():
                logger.info("=> loading checkpoint '{}'".format(args.resume))
            # checkpoint = torch.load(args.resume)
            checkpoint = torch.load(args.resume, map_location=lambda storage, loc: storage.cuda())
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            if main_process():
                logger.info("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
        else:
            if main_process():
                logger.info("=> no checkpoint found at '{}'".format(args.resume))
    value_scale = 255
    mean = [0.485, 0.456, 0.406]
    mean = [item * value_scale for item in mean]
    std = [0.229, 0.224, 0.225]
    std = [item * value_scale for item in std]
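
    # Training augmentation: random scale, rotation, blur, and flip, followed by a random crop.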
    train_transform = transform.Compose([
        transform.RandScale([args.scale_min, args.scale_max]),
        transform.RandRotate([args.rotate_min, args.rotate_max], padding=mean, ignore_label=args.ignore_label),
        transform.RandomGaussianBlur(),
        transform.RandomHorizontalFlip(),
        transform.Crop([args.train_h, args.train_w], crop_type='rand', padding=mean, ignore_label=args.ignore_label),
        transform.ToTensor(),
        transform.Normalize(mean=mean, std=std)])
    train_data = dataset.SemData(split='train', data_root=args.data_root, data_list=args.train_list, transform=train_transform, dataset=args.dataset)
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(train_data)
    else:
        train_sampler = None
    train_loader = torch.utils.data.DataLoader(train_data, batch_size=args.batch_size, shuffle=(train_sampler is None), num_workers=args.workers, pin_memory=True, sampler=train_sampler, drop_last=True)
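
    # Validation uses a deterministic center crop; when training on GTAV, an extra
    # Cityscapes val loader can be built for cross-domain evaluation.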
    if args.evaluate:
        val_transform = transform.Compose([
            transform.Crop([args.train_h, args.train_w], crop_type='center', padding=mean, ignore_label=args.ignore_label),
            transform.ToTensor(),
            transform.Normalize(mean=mean, std=std)])
        val_data = dataset.SemData(split='val', data_root=args.data_root, data_list=args.val_list, transform=val_transform, dataset=args.dataset)
        if args.distributed:
            val_sampler = torch.utils.data.distributed.DistributedSampler(val_data)
        else:
            val_sampler = None
        val_loader = torch.utils.data.DataLoader(val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler)
        if args.dataset == 'gtav':
            if args.evaluate_on_cityscapes:
                cityscapes_val_data = dataset.SemData(split='val', data_root='dataset/cityscapes/', data_list='dataset/cityscapes/val.txt', transform=val_transform, dataset='cityscapes')
                cityscapes_val_loader = torch.utils.data.DataLoader(cityscapes_val_data, batch_size=args.batch_size_val, shuffle=False, num_workers=args.workers, pin_memory=True, sampler=val_sampler)
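
    # Main loop: train for one epoch, log metrics, checkpoint, and periodically validate.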
    for epoch in range(args.start_epoch, args.epochs):
        epoch_log = epoch + 1
        if args.distributed:
            train_sampler.set_epoch(epoch)
        loss_train, mIoU_train, mAcc_train, allAcc_train = train(train_loader, model, optimizer, epoch)
        if main_process():
            writer.add_scalar('loss_train', loss_train, epoch_log)
            writer.add_scalar('mIoU_train', mIoU_train, epoch_log)
            writer.add_scalar('mAcc_train', mAcc_train, epoch_log)
            writer.add_scalar('allAcc_train', allAcc_train, epoch_log)
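
        # Save a checkpoint every save_freq epochs and keep only the two most recent ones on disk.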
        if (epoch_log % args.save_freq == 0) and main_process():
            filename = args.save_path + '/train_epoch_' + str(epoch_log) + '.pth'
            logger.info('Saving checkpoint to: ' + filename)
            torch.save({'epoch': epoch_log, 'state_dict': model.state_dict(), 'optimizer': optimizer.state_dict()}, filename)
            if epoch_log / args.save_freq > 2:
                deletename = args.save_path + '/train_epoch_' + str(epoch_log - args.save_freq * 2) + '.pth'
                os.remove(deletename)
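
        # Evaluate every eval_freq epochs; for GTAV training, optionally also
        # evaluate on the Cityscapes val set and log those metrics separately.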
        if args.evaluate:
            if epoch_log % args.eval_freq == 0:
                loss_val, mIoU_val, mAcc_val, allAcc_val = validate(val_loader, model, criterion)
                if main_process():
                    writer.add_scalar('loss_val', loss_val, epoch_log)
                    writer.add_scalar('mIoU_val', mIoU_val, epoch_log)
                    writer.add_scalar('mAcc_val', mAcc_val, epoch_log)
                    writer.add_scalar('allAcc_val', allAcc_val, epoch_log)
                if args.dataset == 'gtav':
                    if args.evaluate_on_cityscapes:
                        loss_val_city, mIoU_val_city, mAcc_val_city, allAcc_val_city = validate(cityscapes_val_loader, model, criterion)
                        if main_process():
                            writer.add_scalar('loss_val_city', loss_val_city, epoch_log)
                            writer.add_scalar('mIoU_val_city', mIoU_val_city, epoch_log)
                            writer.add_scalar('mAcc_val_city', mAcc_val_city, epoch_log)
                            writer.add_scalar('allAcc_val_city', allAcc_val_city, epoch_log)