in pose_estimation/train.py [0:0]
def main():
args = parse_args()
reset_config(config, args)
logger, final_output_dir, tb_log_dir = create_logger(
config, args.cfg, 'train')
logger.info(pprint.pformat(args))
logger.info(pprint.pformat(config))
# cudnn related setting
cudnn.benchmark = config.CUDNN.BENCHMARK
torch.backends.cudnn.deterministic = config.CUDNN.DETERMINISTIC
torch.backends.cudnn.enabled = config.CUDNN.ENABLED
model = eval('models.'+config.MODEL.NAME+'.get_pose_net')(
config, is_train=True
)
# copy model file
this_dir = os.path.dirname(__file__)
shutil.copy2(
os.path.join(this_dir, '../lib/models', config.MODEL.NAME + '.py'),
final_output_dir)
writer_dict = {
'writer': SummaryWriter(log_dir=tb_log_dir),
'train_global_steps': 0,
'valid_global_steps': 0,
}
dump_input = torch.rand((config.TRAIN.BATCH_SIZE,
3,
config.MODEL.IMAGE_SIZE[1],
config.MODEL.IMAGE_SIZE[0]))
writer_dict['writer'].add_graph(model, (dump_input, ), verbose=False)
gpus = [int(i) for i in config.GPUS.split(',')]
model = torch.nn.DataParallel(model, device_ids=gpus).cuda()
# define loss function (criterion) and optimizer
criterion = JointsMSELoss(
use_target_weight=config.LOSS.USE_TARGET_WEIGHT
).cuda()
optimizer = get_optimizer(config, model)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer, config.TRAIN.LR_STEP, config.TRAIN.LR_FACTOR
)
# Data loading code
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
train_dataset = eval('dataset.'+config.DATASET.DATASET)(
config,
config.DATASET.ROOT,
config.DATASET.TRAIN_SET,
True,
transforms.Compose([
transforms.ToTensor(),
normalize,
])
)
valid_dataset = eval('dataset.'+config.DATASET.DATASET)(
config,
config.DATASET.ROOT,
config.DATASET.TEST_SET,
False,
transforms.Compose([
transforms.ToTensor(),
normalize,
])
)
train_loader = torch.utils.data.DataLoader(
train_dataset,
batch_size=config.TRAIN.BATCH_SIZE*len(gpus),
shuffle=config.TRAIN.SHUFFLE,
num_workers=config.WORKERS,
pin_memory=True
)
valid_loader = torch.utils.data.DataLoader(
valid_dataset,
batch_size=config.TEST.BATCH_SIZE*len(gpus),
shuffle=False,
num_workers=config.WORKERS,
pin_memory=True
)
best_perf = 0.0
best_model = False
for epoch in range(config.TRAIN.BEGIN_EPOCH, config.TRAIN.END_EPOCH):
lr_scheduler.step()
# train for one epoch
train(config, train_loader, model, criterion, optimizer, epoch,
final_output_dir, tb_log_dir, writer_dict)
# evaluate on validation set
perf_indicator = validate(config, valid_loader, valid_dataset, model,
criterion, final_output_dir, tb_log_dir,
writer_dict)
if perf_indicator > best_perf:
best_perf = perf_indicator
best_model = True
else:
best_model = False
logger.info('=> saving checkpoint to {}'.format(final_output_dir))
save_checkpoint({
'epoch': epoch + 1,
'model': get_model_name(config),
'state_dict': model.state_dict(),
'perf': perf_indicator,
'optimizer': optimizer.state_dict(),
}, best_model, final_output_dir)
final_model_state_file = os.path.join(final_output_dir,
'final_state.pth.tar')
logger.info('saving final model state to {}'.format(
final_model_state_file))
torch.save(model.module.state_dict(), final_model_state_file)
writer_dict['writer'].close()