in eval_video.py
def main(args, writer):
# Create Logger
logger, training_stats = initialize_exp(
args, "epoch", "loss", "prec1", "prec5", "loss_val", "prec1_val", "prec5_val"
)
# Set CudNN benchmark
torch.backends.cudnn.benchmark = True
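    # benchmark mode lets cuDNN pick the fastest convolution algorithms for the fixed-size clips used here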
# Load model
logger.info("Loading model")
model = load_model(
model_type=args.model,
vid_base_arch=args.vid_base_arch,
aud_base_arch=args.aud_base_arch,
pretrained=args.pretrained,
norm_feat=False,
use_mlp=args.use_mlp,
num_classes=256,
args=args,
)
# Load model weights
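    # args.weights_path may arrive as the literal string 'None' or as a real None; treat both as "no weights"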
    if isinstance(args.weights_path, str):
        weight_path_not_none = args.weights_path != 'None'
    else:
        weight_path_not_none = args.weights_path is not None
if not args.pretrained and weight_path_not_none:
logger.info("Loading model weights")
if os.path.exists(args.weights_path):
ckpt_dict = torch.load(args.weights_path)
            try:
                model_weights = ckpt_dict["state_dict"]
            except KeyError:
                model_weights = ckpt_dict["model"]
epoch = ckpt_dict["epoch"]
logger.info(f"Epoch checkpoint: {epoch}")
load_model_parameters(model, model_weights)
logger.info(f"Loading model done")
# Add FC layer to model for fine-tuning or feature extracting
model = load_model_finetune(
args,
model.video_network.base,
pooling_arch=model.video_pooling if args.agg_model else None,
num_ftrs=model.encoder_dim,
num_classes=NUM_CLASSES[args.dataset],
use_dropout=args.use_dropout,
use_bn=args.use_bn,
use_l2_norm=args.use_l2_norm,
dropout=0.9,
agg_model=args.agg_model,
)
# Create DataParallel model
model = model.cuda()
model = torch.nn.DataParallel(model)
model_without_ddp = model.module
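    # keep a handle to the unwrapped module so the parameter groups and checkpoint loading below
    # address the raw model rather than the DataParallel wrapper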
# Get params for optimization
params = []
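    # Build per-group optimizer settings: the classifier head trains with head_lr; full fine-tuning
    # also adds the base encoder (base_lr) and, for aggregation models, the pooling arch (tsf_lr)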
    if args.feature_extract:  # feature extraction: train only the classifier head
logger.info("Getting params for feature-extracting")
for name, param in model_without_ddp.classifier.named_parameters():
logger.info((name, param.shape))
params.append(
{
'params': param,
'lr': args.head_lr,
'weight_decay': args.weight_decay
})
    else:  # fine-tuning: train the classifier head and the base encoder
logger.info("Getting params for finetuning")
for name, param in model_without_ddp.classifier.named_parameters():
logger.info((name, param.shape))
params.append(
{
'params': param,
'lr': args.head_lr,
'weight_decay': args.weight_decay
})
for name, param in model_without_ddp.base.named_parameters():
logger.info((name, param.shape))
params.append(
{
'params': param,
'lr': args.base_lr,
'weight_decay': args.wd_base
})
if args.agg_model:
logger.info("Adding pooling arch params to be optimized")
for name, param in model_without_ddp.pooling_arch.named_parameters():
if param.requires_grad and param.dim() >= 1:
logger.info(f"Adding {name}({param.shape}), wd: {args.wd_tsf}")
params.append(
{
'params': param,
'lr': args.tsf_lr,
'weight_decay': args.wd_tsf
})
else:
logger.info(f"Not adding {name} to be optimized")
logger.info('\n===========Check Grad============')
for name, param in model_without_ddp.named_parameters():
logger.info((name, param.requires_grad))
logger.info('=================================\n')
logger.info("Creating AV Datasets")
dataset = AVideoDataset(
ds_name=args.dataset,
root_dir=args.root_dir,
mode='train',
num_train_clips=args.train_clips_per_video,
decode_audio=False,
center_crop=False,
fold=args.fold,
ucf101_annotation_path=args.ucf101_annotation_path,
hmdb51_annotation_path=args.hmdb51_annotation_path,
args=args,
)
dataset_test = AVideoDataset(
ds_name=args.dataset,
root_dir=args.root_dir,
mode='test',
decode_audio=False,
num_spatial_crops=args.num_spatial_crops,
num_ensemble_views=args.val_clips_per_video,
ucf101_annotation_path=args.ucf101_annotation_path,
hmdb51_annotation_path=args.hmdb51_annotation_path,
fold=args.fold,
args=args,
)
# Creating dataloaders
logger.info("Creating data loaders")
data_loader = torch.utils.data.DataLoader(
dataset,
batch_size=args.batch_size,
num_workers=args.workers,
pin_memory=True,
drop_last=True,
shuffle=True
)
data_loader_test = torch.utils.data.DataLoader(
dataset_test,
batch_size=args.batch_size,
num_workers=args.workers,
pin_memory=True,
drop_last=False
)
    # Set up the SGD optimizer over the parameter groups defined above
logger.info(f"Using SGD with lr: {args.head_lr}, wd: {args.weight_decay}")
optimizer = torch.optim.SGD(
params,
lr=args.head_lr,
momentum=args.momentum,
weight_decay=args.weight_decay
)
# Multi-step LR scheduler
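    # Milestones are given in absolute epochs and shifted by the warmup length; with warmup enabled,
    # GradualWarmupScheduler ramps the LR up first and then hands off to the MultiStepLR schedule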
if args.use_scheduler:
        milestones = [int(m) - args.lr_warmup_epochs for m in args.lr_milestones.split(',')]
logger.info(f"Num. of Epochs: {args.epochs}, Milestones: {milestones}")
if args.lr_warmup_epochs > 0:
logger.info(f"Using scheduler with {args.lr_warmup_epochs} warmup epochs")
scheduler_step = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=milestones,
gamma=args.lr_gamma
)
lr_scheduler = GradualWarmupScheduler(
optimizer,
multiplier=8,
total_epoch=args.lr_warmup_epochs,
after_scheduler=scheduler_step
)
        else:  # no warmup, just multi-step decay
logger.info("Using scheduler w/out warmup")
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
optimizer,
milestones=milestones,
gamma=args.lr_gamma
)
else:
lr_scheduler = None
# Checkpointing
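    # Resume the full training state (model, optimizer, scheduler, start epoch) from the latest checkpoint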
if args.resume:
ckpt_path = os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth')
checkpoint = torch.load(ckpt_path, map_location='cpu')
model_without_ddp.load_state_dict(checkpoint['model'])
optimizer.load_state_dict(checkpoint['optimizer'])
if lr_scheduler is not None:
lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
args.start_epoch = checkpoint['epoch']
logger.info(f"Resuming from epoch: {args.start_epoch}")
    # Only perform evaluation
if args.test_only:
scores_val = evaluate(
model,
data_loader_test,
epoch=args.start_epoch,
writer=writer,
ds=args.dataset,
)
_, vid_acc1, vid_acc5 = scores_val
return vid_acc1, vid_acc5, args.start_epoch
start_time = time.time()
best_vid_acc_1 = -1
best_vid_acc_5 = -1
best_epoch = 0
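    # Main loop: train every epoch, evaluate from epoch 7 onward, and track the best video-level top-1/top-5 accuracies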
for epoch in range(args.start_epoch, args.epochs):
logger.info(f'Start training epoch: {epoch}')
scores = train(
model,
optimizer,
data_loader,
epoch,
writer=writer,
ds=args.dataset,
)
logger.info(f'Start evaluating epoch: {epoch}')
        if lr_scheduler is not None:
            lr_scheduler.step()
        if epoch > 6:  # evaluate every epoch once past the early epochs
scores_val = evaluate(
model,
data_loader_test,
epoch=epoch,
writer=writer,
ds=args.dataset,
)
_, vid_acc1, vid_acc5 = scores_val
training_stats.update(scores + scores_val)
if vid_acc1 > best_vid_acc_1:
best_vid_acc_1 = vid_acc1
best_vid_acc_5 = vid_acc5
best_epoch = epoch
if args.output_dir:
logger.info(f'Saving checkpoint to: {args.output_dir}')
save_checkpoint(args, epoch, model, optimizer, lr_scheduler, ckpt_freq=1)
total_time = time.time() - start_time
total_time_str = str(datetime.timedelta(seconds=int(total_time)))
logger.info(f'Training time {total_time_str}')
return best_vid_acc_1, best_vid_acc_5, best_epoch