in finetune_video.py [0:0]
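# NOTE: this excerpt assumes the following module-level imports. The standard-library
# and torch imports are inferred from usage; the remaining helpers (initialize_exp,
# load_model, load_model_parameters, Finetune_Model, get_video_dim, NUM_CLASSES,
# AVideoDataset, GradualWarmupScheduler, train, evaluate, save_checkpoint) are assumed
# to come from the repo's own modules, whose import paths are not shown here.
import datetime
import os
import time

import torch
import torch.utils.data
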
def main(args, writer):
    # Create logger
    logger, training_stats = initialize_exp(
        args, "epoch", "loss", "prec1", "prec5",
        "loss_val", "prec1_val", "prec5_val"
    )

    # Set cuDNN benchmark
    torch.backends.cudnn.benchmark = True
    # Load model
    logger.info("Loading model")
    model = load_model(
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        pretrained=args.pretrained,
        num_classes=args.num_clusters,
        norm_feat=False,
        use_mlp=args.use_mlp,
        headcount=args.headcount,
    )

    # Load model weights
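    # args.weights_path may arrive either as the literal string 'None' (e.g. a
    # command-line default) or as a real None, so both spellings are treated as
    # "no checkpoint to load".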
    weight_path_type = type(args.weights_path)
    if weight_path_type == str:
        weight_path_not_none = args.weights_path != 'None'
    else:
        weight_path_not_none = args.weights_path is not None
    if not args.pretrained and weight_path_not_none:
        logger.info("Loading model weights")
        if os.path.exists(args.weights_path):
            ckpt_dict = torch.load(args.weights_path)
            model_weights = ckpt_dict["model"]
            logger.info(f"Epoch checkpoint: {args.ckpt_epoch}")
            load_model_parameters(model, model_weights)
    logger.info("Loading model done")
    # Add FC layer to model for fine-tuning or feature extracting
    model = Finetune_Model(
        model.video_network.base,
        get_video_dim(vid_base_arch=args.vid_base_arch),
        NUM_CLASSES[args.dataset],
        use_dropout=args.use_dropout,
        use_bn=args.use_bn,
        use_l2_norm=args.use_l2_norm,
        dropout=0.7
    )

    # Create DataParallel model
    model = model.cuda()
    model = torch.nn.DataParallel(model)
    model_without_ddp = model.module
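    # model_without_ddp is used below to reach the classifier / backbone parameters
    # and to load state_dicts without the DataParallel "module." prefix.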

    # Get params for optimization
    params = []
    if args.feature_extract:  # feature extraction: train only the classifier head
        for name, param in model_without_ddp.classifier.named_parameters():
            logger.info((name, param.shape))
            params.append({
                'params': param,
                'lr': args.head_lr,
                'weight_decay': args.weight_decay
            })
    else:  # fine-tuning: train both the classifier head and the backbone
        for name, param in model_without_ddp.classifier.named_parameters():
            logger.info((name, param.shape))
            params.append({
                'params': param,
                'lr': args.head_lr,
                'weight_decay': args.weight_decay
            })
        for name, param in model_without_ddp.base.named_parameters():
            logger.info((name, param.shape))
            params.append({
                'params': param,
                'lr': args.base_lr,
                'weight_decay': args.wd_base
            })
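    # In feature-extract mode the backbone's parameters are never handed to the
    # optimizer, so only the new classifier head receives updates; in fine-tuning mode
    # the backbone trains with its own learning rate (base_lr) and weight decay (wd_base).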
logger.info("Creating AV Datasets")
dataset = AVideoDataset(
ds_name=args.dataset,
root_dir=args.root_dir,
mode='train',
num_frames=args.clip_len,
sample_rate=args.steps_bet_clips,
num_train_clips=args.train_clips_per_video,
train_crop_size=128 if args.augtype == 1 else 224,
seed=None,
fold=args.fold,
colorjitter=args.colorjitter,
temp_jitter=True,
center_crop=False,
target_fps=30,
decode_audio=False,
)
dataset_test = AVideoDataset(
ds_name=args.dataset,
root_dir=args.root_dir,
mode='test',
num_frames=args.clip_len,
sample_rate=args.steps_bet_clips,
test_crop_size=128 if args.augtype == 1 else 224,
num_spatial_crops=args.num_spatial_crops,
num_ensemble_views=args.val_clips_per_video,
seed=None,
fold=args.fold,
colorjitter=args.test_time_cj,
temp_jitter=True,
target_fps=30,
decode_audio=False,
)
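    # The test split draws num_spatial_crops x val_clips_per_video clips per video;
    # video-level accuracy is presumably obtained by averaging these clip-level
    # predictions inside evaluate().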

    # Creating dataloaders
    logger.info("Creating data loaders")
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=None,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True,
        shuffle=True
    )
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=None,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=False
    )
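    # The train loader shuffles and drops the last incomplete batch; the test loader
    # keeps every clip so the whole evaluation set is covered.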

    # Set up the optimizer
    if args.optim_name == 'sgd':
        optimizer = torch.optim.SGD(
            params,
            lr=args.head_lr,
            momentum=args.momentum,
            weight_decay=args.weight_decay
        )
    elif args.optim_name == 'adam':
        optimizer = torch.optim.Adam(
            params,
            lr=args.head_lr,
            weight_decay=args.weight_decay
        )
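    # Because every entry in params defines its own 'lr' and 'weight_decay', those
    # per-group values override the defaults passed to the optimizer constructor above.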

    # Multi-step LR scheduler
    if args.use_scheduler:
        lr_milestones = args.lr_milestones.split(',')
        milestones = [int(m) - args.lr_warmup_epochs for m in lr_milestones]
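        # Milestones are shifted earlier by the warmup length so that, once the warmup
        # epochs are counted, the LR drops still land on the epochs listed in
        # args.lr_milestones.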
        if args.lr_warmup_epochs > 0:
            scheduler_step = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=milestones,
                gamma=args.lr_gamma
            )
            multiplier = 8
            lr_scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=multiplier,
                total_epoch=args.lr_warmup_epochs,
                after_scheduler=scheduler_step
            )
        else:  # no warmup, just multi-step decay
            lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer,
                milestones=milestones,
                gamma=args.lr_gamma
            )
    else:
        lr_scheduler = None

    # Checkpointing
    if args.resume:
        ckpt_path = os.path.join(
            args.output_dir, 'checkpoints', 'checkpoint.pth')
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        if lr_scheduler is not None:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch']
        logger.info(f"Resuming from epoch: {args.start_epoch}")

    # Only perform evaluation
    if args.test_only:
        scores_val = evaluate(
            model,
            data_loader_test,
            epoch=args.start_epoch,
            writer=writer,
            ds=args.dataset,
        )
        _, vid_acc1, vid_acc5 = scores_val
        return vid_acc1, vid_acc5, args.start_epoch

    start_time = time.time()
    best_vid_acc_1 = -1
    best_vid_acc_5 = -1
    best_epoch = 0
    for epoch in range(args.start_epoch, args.epochs):
        logger.info(f'Start training epoch: {epoch}')
        scores = train(
            model,
            optimizer,
            data_loader,
            epoch,
            writer=writer,
            ds=args.dataset,
        )
        logger.info(f'Start evaluating epoch: {epoch}')
        if lr_scheduler is not None:  # scheduler is None when args.use_scheduler is off
            lr_scheduler.step()
        scores_val = evaluate(
            model,
            data_loader_test,
            epoch=epoch,
            writer=writer,
            ds=args.dataset,
        )
        _, vid_acc1, vid_acc5 = scores_val
        training_stats.update(scores + scores_val)
        if vid_acc1 > best_vid_acc_1:
            best_vid_acc_1 = vid_acc1
            best_vid_acc_5 = vid_acc5
            best_epoch = epoch
        if args.output_dir:
            logger.info(f'Saving checkpoint to: {args.output_dir}')
            save_checkpoint(args, epoch, model, optimizer, lr_scheduler,
                            ckpt_freq=1)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f'Training time {total_time_str}')
    return best_vid_acc_1, best_vid_acc_5, best_epoch
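
# Hypothetical usage sketch (not part of the original file): main expects an
# argparse-style namespace plus a writer for scalar logging (e.g. a TensorBoard
# SummaryWriter); the argument parser itself lives elsewhere in the repo and its
# name is assumed here.
#
#     from torch.utils.tensorboard import SummaryWriter
#     args = parse_args()                      # repo's own parser (assumed name)
#     writer = SummaryWriter(args.output_dir)
#     best_top1, best_top5, best_epoch = main(args, writer)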