# func/train.py
def main(cfg):
logger = logging.getLogger(__name__)
dist_info, device, writer = initial_setup(cfg, logger)
# Data loading code
logger.info("Loading data")
logger.info("\t Loading datasets")
st = time.time()
    # TODO: Factor these pipelines out into a get_transforms() helper.
    # This has gotten too complex; clean up and make the interface better.
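    # Training pipeline: resize, random horizontal flip, color jitter, optional
    # pixel-value scaling and channel reversal, normalization, and an optional
    # random crop (appended below).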
transform_train = [
T.ToTensorVideo(),
T.Resize(_get_resize_shape(cfg.data_train)),
T.RandomHorizontalFlipVideo(cfg.data_train.flip_p),
T.ColorJitterVideo(brightness=cfg.data_train.color_jitter_brightness,
contrast=cfg.data_train.color_jitter_contrast,
saturation=cfg.data_train.color_jitter_saturation,
hue=cfg.data_train.color_jitter_hue),
torchvision.transforms.Lambda(
lambda x: x * cfg.data_train.scale_pix_val),
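        # Optionally reverse RGB -> BGR; Compose([]) is used as a no-op identity.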
        (torchvision.transforms.Lambda(lambda x: x[[2, 1, 0], ...])
         if cfg.data_train.reverse_channels
         else torchvision.transforms.Compose([])),
T.NormalizeVideo(**_get_pixel_mean_std(cfg.data_train)),
]
if cfg.data_train.crop_size is not None:
transform_train.append(
T.RandomCropVideo(
(cfg.data_train.crop_size, cfg.data_train.crop_size)), )
transform_train = torchvision.transforms.Compose(transform_train)
transform_eval = [
T.ToTensorVideo(),
T.Resize(_get_resize_shape(cfg.data_eval)),
torchvision.transforms.Lambda(
lambda x: x * cfg.data_eval.scale_pix_val),
        (torchvision.transforms.Lambda(lambda x: x[[2, 1, 0], ...])
         if cfg.data_eval.reverse_channels
         else torchvision.transforms.Compose([])),
T.NormalizeVideo(**_get_pixel_mean_std(cfg.data_eval)),
]
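    # At eval time, take multiple spatial crops per clip (optionally also
    # flipped), rather than a single random crop as in training.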
if cfg.data_eval.crop_size is not None:
transform_eval.append(
T.MultiCropVideo(
(cfg.data_eval.crop_size, cfg.data_eval.crop_size),
cfg.data_eval.eval_num_crops, cfg.data_eval.eval_flip_crops))
transform_eval = torchvision.transforms.Compose(transform_eval)
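    # Build one dataset per config key prefixed with DATASET_TRAIN_CFG_KEY;
    # multiple training datasets are concatenated below.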
datasets_train = [
get_dataset(getattr(cfg, el), cfg.data_train, transform_train, logger)
for el in cfg.keys() if el.startswith(DATASET_TRAIN_CFG_KEY)
]
if len(datasets_train) > 1:
dataset = torch.utils.data.ConcatDataset(datasets_train)
else:
dataset = datasets_train[0]
# could be multiple test datasets
datasets_test = {
el[len(DATASET_EVAL_CFG_KEY):]:
get_dataset(getattr(cfg, el), cfg.data_eval, transform_eval, logger)
for el in cfg.keys() if el.startswith(DATASET_EVAL_CFG_KEY)
}
logger.info("Took %d", time.time() - st)
logger.info("Creating data loaders")
train_sampler = None
test_samplers = {key: None for key in datasets_test}
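    # Datasets that expose a `video_clips` attribute (torchvision-style
    # VideoClips) use clip samplers: a random subset of clips per video for
    # training, and uniformly spaced clips per video for evaluation.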
if hasattr(dataset, 'video_clips'):
        assert cfg.train.shuffle_data, 'TODO: support non-shuffled data'
train_sampler = RandomClipSampler(getattr(dataset, 'video_clips'),
cfg.data_train.train_bs_multiplier)
test_samplers = {
key: UniformClipSampler(val.video_clips,
cfg.data_eval.val_clips_per_video)
for key, val in datasets_test.items()
}
if dist_info['distributed']:
train_sampler = DistributedSampler(train_sampler)
            test_samplers = {
                key: DistributedSampler(val)
                for key, val in test_samplers.items()
            }
elif dist_info['distributed']:
# Distributed, but doesn't have video_clips
if cfg.data_train.use_dist_sampler:
train_sampler = torch.utils.data.distributed.DistributedSampler(
dataset,
num_replicas=dist_info['world_size'],
rank=dist_info['rank'],
shuffle=cfg.train.shuffle_data)
if cfg.data_eval.use_dist_sampler:
test_samplers = {
key: torch.utils.data.distributed.DistributedSampler(
val,
num_replicas=dist_info['world_size'],
rank=dist_info['rank'],
shuffle=False)
for key, val in datasets_test.items()
}
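    # collate_fn_remove_audio, per its name, presumably drops the audio stream
    # from each sample when batching.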
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=cfg.train.batch_size,
sampler=train_sampler,
num_workers=cfg.data_train.workers,
        pin_memory=False,  # usually hurts performance here
shuffle=(train_sampler is None and cfg.train.shuffle_data),
collate_fn=collate_fn_remove_audio,
)
data_loaders_test = {
key: torch.utils.data.DataLoader(
val,
            # No backprop during eval, so a larger batch size fits in memory
batch_size=cfg.eval.batch_size or cfg.train.batch_size * 4,
sampler=test_samplers[key],
num_workers=cfg.data_eval.workers,
            pin_memory=False,  # usually hurts performance here
shuffle=False,
collate_fn=collate_fn_remove_audio,
)
for key, val in datasets_test.items()
}
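    # `dataset.classes` maps each output/task name to its list of classes; the
    # model only needs the per-task class counts.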
num_classes = {key: len(val) for key, val in dataset.classes.items()}
logger.info('Creating model with %s classes', num_classes)
model = base_model.BaseModel(cfg.model,
num_classes=num_classes,
class_mappings=dataset.class_mappings)
logger.debug('Model: %s', model)
if dist_info['distributed'] and cfg.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
if cfg.train.init_from_model:
        # Each entry can take one of these forms:
        #   [ckpt_path]                                     (init the full model)
        #   [module_name, ckpt_path]                        (init a submodule)
        #   [module_name, ckpt_modules_to_keep, ckpt_path]
for module_ckpt in cfg.train.init_from_model:
elts = module_ckpt
if len(elts) == 1:
model_to_init = model
ckpt_modules_to_keep = None
ckpt_path = elts[0]
elif len(elts) == 2:
model_to_init = operator.attrgetter(elts[0])(model)
ckpt_modules_to_keep = None
ckpt_path = elts[1]
elif len(elts) == 3:
model_to_init = operator.attrgetter(elts[0])(model)
ckpt_modules_to_keep = elts[1]
ckpt_path = elts[2]
else:
raise ValueError(f'Incorrect formatting {module_ckpt}')
init_model(model_to_init, ckpt_path, ckpt_modules_to_keep, logger)
model.to(device)
if cfg.opt.classifier_only:
assert len(cfg.opt.lr_wd) == 1
assert cfg.opt.lr_wd[0][0] == 'classifier'
model = _set_all_bn_to_not_track_running_mean(model)
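    # Build optimizer parameter groups. cfg.opt.lr_wd is a list of
    # (module name(s), LR, weight decay) entries, e.g. (hypothetical values):
    #   lr_wd: [['backbone', 0.001, 1e-4], ['classifier', 0.01, 1e-4]]
    # '__all__' selects the whole model. Bias and BatchNorm parameters go into
    # a separate group so their weight decay can be scaled by
    # cfg.opt.bias_bn_wd_scale.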
params = []
for this_module_names, this_lr, this_wd in cfg.opt.lr_wd:
if OmegaConf.get_type(this_module_names) != list:
this_module_names = [this_module_names]
this_modules = [
operator.attrgetter(el)(model) if el != '__all__' else model
for el in this_module_names
]
this_params_bias_bn = {}
this_params_rest = {}
for this_module_name, this_module in zip(this_module_names,
this_modules):
for name, param in this_module.named_parameters():
# ignore the param without grads
if not param.requires_grad:
continue
                # Match on 'bias' (not '.bias'): parameters of the top-level
                # module have no submodule prefix in their name.
if name.endswith('bias') or ('.bn' in name):
this_params_bias_bn[this_module_name + '.' + name] = param
else:
this_params_rest[this_module_name + '.' + name] = param
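        # Linear LR scaling: multiply the configured LR by the number of
        # processes, and optionally by the (per-process) batch size.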
this_scaled_lr = this_lr * dist_info['world_size']
if cfg.opt.scale_lr_by_bs:
this_scaled_lr *= cfg.train.batch_size
params.append({
'params': this_params_rest.values(),
'lr': this_scaled_lr,
'weight_decay': this_wd,
})
logger.info('Using LR %f WD %f for parameters %s', params[-1]['lr'],
params[-1]['weight_decay'], this_params_rest.keys())
params.append({
'params': this_params_bias_bn.values(),
'lr': this_scaled_lr,
'weight_decay': this_wd * cfg.opt.bias_bn_wd_scale,
})
logger.info('Using LR %f WD %f for parameters %s', params[-1]['lr'],
params[-1]['weight_decay'], this_params_bias_bn.keys())
    # Drop parameter groups with LR 0 and freeze their params; saves GPU
    # memory and compute.
params_final = []
for param_lr in params:
if param_lr['lr'] != 0.0:
params_final.append(param_lr)
else:
for param in param_lr['params']:
param.requires_grad = False
optimizer = hydra.utils.instantiate(cfg.opt.optimizer, params_final)
    # The schedulers operate per iteration rather than per epoch, so warmup
    # can span a fraction of an epoch or extend across epoch boundaries.
main_scheduler = hydra.utils.instantiate(
cfg.opt.scheduler,
optimizer,
iters_per_epoch=len(data_loader),
world_size=dist_info['world_size'])
lr_scheduler = hydra.utils.instantiate(cfg.opt.warmup,
optimizer,
main_scheduler,
iters_per_epoch=len(data_loader),
world_size=dist_info['world_size'])
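    # Auto-resume: if a checkpoint with the default name exists in the current
    # working directory (presumably the hydra run dir), restore model,
    # optimizer, and scheduler state and continue from the stored epoch.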
last_saved_ckpt = CKPT_FNAME
start_epoch = 0
if os.path.isfile(last_saved_ckpt):
checkpoint = torch.load(last_saved_ckpt, map_location='cpu')
model.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
start_epoch = checkpoint['epoch']
logger.warning('Loaded model from %s (ep %f)', last_saved_ckpt,
start_epoch)
if dist_info['distributed'] and not cfg.eval.eval_fn.only_run_featext:
        # If only extracting features, each GPU runs inference independently,
        # so there is no need for communication between the model replicas.
logger.info('Wrapping model into DDP')
model = torch.nn.parallel.DistributedDataParallel(
model,
device_ids=[dist_info['gpu']],
output_device=dist_info['gpu'])
elif cfg.data_parallel:
logger.info('Wrapping model into DP')
device_ids = range(dist_info['world_size'])
model = torch.nn.parallel.DataParallel(model, device_ids=device_ids)
    # TODO: add an option here to support val-mode training
    # Pass in the training dataset, since it is used to compute class
    # weights, etc.
train_eval_op = hydra.utils.instantiate(cfg.train_eval_op,
model,
device,
dataset,
_recursive_=False)
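    # train_eval_op bundles the model, device, and train dataset with the
    # train/eval step logic; _recursive_=False leaves its nested configs
    # un-instantiated for it to handle.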
if cfg.test_only:
logger.info("Starting test_only")
hydra.utils.call(cfg.eval.eval_fn, train_eval_op, data_loaders_test,
writer, logger, start_epoch)
return
logger.info("Start training")
start_time = time.time()
# Get training metric logger
stat_loggers = get_default_loggers(writer, start_epoch, logger)
best_acc1 = 0.0
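    # The stored epoch may be fractional (checkpoint saved mid-epoch); split it
    # into the integer epoch and the fraction of it already completed.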
partial_epoch = start_epoch - int(start_epoch)
start_epoch = int(start_epoch)
last_saved_time = datetime.datetime(1, 1, 1, 0, 0)
    epoch = 0  # Used when writing the final checkpoint, so initialize in case the loop never runs
for epoch in range(start_epoch, cfg.train.num_epochs):
if dist_info['distributed'] and train_sampler is not None:
train_sampler.set_epoch(epoch)
last_saved_time = hydra.utils.call(cfg.train.train_one_epoch_fn,
train_eval_op, optimizer,
lr_scheduler, data_loader, epoch,
partial_epoch,
stat_loggers["train"], logger,
last_saved_time)
partial_epoch = 0 # Reset, for future epochs
store_checkpoint([CKPT_FNAME], model, optimizer, lr_scheduler,
epoch + 1)
if cfg.train.eval_freq and epoch % cfg.train.eval_freq == 0:
acc1 = hydra.utils.call(cfg.eval.eval_fn, train_eval_op,
data_loaders_test, writer, logger,
epoch + 1)
else:
acc1 = 0
if cfg.train.store_best and acc1 >= best_acc1:
            store_checkpoint(['checkpoint_best.pth'], model, optimizer,
                             lr_scheduler, epoch + 1)
best_acc1 = acc1
if isinstance(lr_scheduler.base_scheduler,
scheduler.ReduceLROnPlateau):
lr_scheduler.step(acc1)
# reset all meters in the metric logger
for log in stat_loggers:
stat_loggers[log].reset_meters()
# Store the final model to checkpoint
store_checkpoint([CKPT_FNAME], model, optimizer, lr_scheduler, epoch + 1)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info('Training time %s', total_time_str)