in main_gdt.py [0:0]
def main(args):
    """Self-supervised GDT pretraining driver.

    Sets up, in order: mixed-precision prerequisites, the output directory,
    distributed state, tensorboard logging (master rank only), the GDT
    encoder (audio-visual or video-text), the SGD optimizer, an optional LR
    scheduler with warmup, checkpoint resumption, the pretraining
    dataloader, and finally runs the epoch training loop.

    Args:
        args: parsed argparse namespace. Mutated in place: ``aug_audio`` is
            replaced by a parameter list, ``start_epoch`` may be advanced on
            resume, and ``init_distributed_mode`` sets rank fields.
    """
    # --- Mixed-precision prerequisites (NVIDIA apex) ----------------------
    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError(
                "Failed to import apex. Please install apex from https://www.github.com/nvidia/apex "
                "to enable mixed-precision training."
            )

    # --- Output directory --------------------------------------------------
    if args.output_dir:
        makedir(args.output_dir)

    # --- Distributed init & signal handling --------------------------------
    if torch.cuda.is_available():
        init_distributed_mode(args)
    init_signal_handler()

    # --- Tensorboard (master rank only) ------------------------------------
    # Non-master workers get writer=None so every later `if writer:` guard
    # becomes a no-op on them.
    tbx_path = os.path.join(args.output_dir, 'tensorboard')
    global_rank = args.rank if args.distributed else 0
    is_master = global_rank == 0
    if is_master:
        writer = SummaryWriter(tbx_path)
        # Dump the full config once so each run is self-describing.
        writer.add_text(
            'args',
            " \n".join(['%s : %s' % (arg, getattr(args, arg)) for arg in vars(args)]),
            0
        )
    else:
        writer = None

    # --- Environment logging ------------------------------------------------
    logger.info(args)
    logger.info(f"torch version: {torch.__version__}")
    device = torch.device(args.device)
    # Fixed-size inputs -> let cudnn autotune convolution algorithms.
    torch.backends.cudnn.benchmark = True

    # --- Model --------------------------------------------------------------
    logger.info("Creating model")
    if args.model == 'av_gdt':
        # Audio-visual GDT encoder.
        model = GDT(
            vid_base_arch=args.vid_base_arch,
            aud_base_arch=args.aud_base_arch,
            pretrained=False,
            norm_feat=args.norm_feat,
            use_mlp=args.use_mlp,
            num_classes=256,
        )
    else:
        # Video-text GDT encoder for pretraining.
        model = TextVid_GDT(
            vid_base_arch=args.vid_base_arch,
            text_base_arch='word2vec',
            pretrained=False,
            norm_feat=args.norm_feat,
            use_mlp=args.use_mlp,
            num_classes=256,
        )
    model.to(device)
    if args.distributed and args.sync_bn:
        logger.info("Sync BN on model")
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    # Keep a handle to the bare module for checkpoint (de)serialization --
    # DDP wraps the model and prefixes state-dict keys otherwise.
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            broadcast_buffers=False
        )
        model_without_ddp = model.module

    # --- Audio augmentation presets -----------------------------------------
    # Replaces the boolean flag with a concrete 4-int parameter list
    # (presumably SpecAugment-style masking parameters -- confirm downstream).
    if args.aug_audio:
        if args.audio_augtype == 'mild':
            args.aug_audio = [1, 1, 2, 5]
        elif args.audio_augtype == 'medium':
            args.aug_audio = [1, 1, 3, 6]
        elif args.audio_augtype == 'heavy':
            args.aug_audio = [2, 2, 3, 6]

    # --- Optimizer ----------------------------------------------------------
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.lr,
        momentum=args.momentum,
        weight_decay=args.weight_decay
    )

    # amp.initialize must run after model/optimizer construction.
    if args.apex:
        model, optimizer = amp.initialize(
            model,
            optimizer,
            opt_level=args.apex_opt_level
        )

    # --- LR scheduler -------------------------------------------------------
    # Milestones are given in absolute epochs; shift them to be relative to
    # the end of warmup so the step schedule lines up after warmup finishes.
    milestones = [
        int(milestone) - args.lr_warmup_epochs
        for milestone in args.lr_milestones.split(',')
    ]
    lr_scheduler = None
    if args.use_scheduler:
        if args.lr_warmup_epochs > 0:
            if args.scheduler_type == 'multi_step':
                logger.info('Using Multi-Step LR scheduler')
                scheduler_step = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer,
                    milestones=milestones,
                    gamma=args.lr_gamma
                )
            else:
                logger.info('Using Cosine Annealing LR scheduler')
                scheduler_step = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optimizer, args.epochs
                )
            # Linear warmup to world_size * base LR, then hand off to the
            # wrapped scheduler.
            lr_scheduler = GradualWarmupScheduler(
                optimizer,
                multiplier=args.world_size,
                total_epoch=args.lr_warmup_epochs,
                after_scheduler=scheduler_step
            )
        else:
            if args.scheduler_type == 'multi_step':
                logger.info('Using Multi-Step LR scheduler')
                lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
                    optimizer,
                    milestones=milestones,
                    gamma=args.lr_gamma
                )
            else:
                logger.info('Using Cosine Annealing LR scheduler')
                lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    optimizer, args.epochs
                )

    # --- Checkpoint resumption ----------------------------------------------
    ckp_path = os.path.join(args.output_dir, 'checkpoints', 'checkpoint.pth')
    if os.path.isfile(ckp_path):
        logger.info('Loading checkpoint')
        checkpoint = torch.load(ckp_path, map_location='cpu')
        model_without_ddp.load_state_dict(checkpoint['model'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        # Bugfix: only restore the scheduler when one is actually in use --
        # resuming with the scheduler disabled used to crash on NoneType.
        if lr_scheduler is not None and checkpoint.get('lr_scheduler') is not None:
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        args.start_epoch = checkpoint['epoch']
        logger.info(f'Restarting at epoch {args.start_epoch}')

    # --- Dataset / dataloader -----------------------------------------------
    if args.dataset == "ht100m":
        # HowTo100M video-text pretraining data.
        ds = HT100M_Dataset(
            csv_file='data/howto.csv',
            video_root=args.root_dir,
            caption_root=args.ht100m_caption_root,
            token_to_word_path='data/dict.npy',
            fps=32 / int(args.sample_rate),
            num_frames=args.clip_len,
            size=args.train_crop_size,
            center_crop=args.center_crop,  # True
        )
    else:
        # Audio-visual datasets: Kinetics-400/600, Audioset, VGG-Sound.
        ds = GDTPretrainDataset(
            ds_name=args.dataset,
            root_dir=args.root_dir,
            mode='train',
            args=args
        )
    print("Creating data loaders", flush=True)
    train_sampler = None
    if args.distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(ds)
    data_loader = torch.utils.data.DataLoader(
        ds,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=None,
        drop_last=True
    )

    # --- Training loop ------------------------------------------------------
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            # Reshuffle the per-rank shards each epoch.
            train_sampler.set_epoch(epoch)
        if writer:
            writer.add_scalar('train/epoch', epoch, epoch)
        logger.info(f'Start training epoch: {epoch}')
        train_one_epoch(
            args,
            data_loader,
            model,
            optimizer,
            device,
            epoch,
            args.print_freq,
            lr_scheduler,
            args.apex,
            writer=writer,
        )
        # Epoch-granularity scheduler step (warmup wrapper included).
        if lr_scheduler:
            lr_scheduler.step()
        if args.output_dir:
            save_checkpoint(args, epoch, model, optimizer, lr_scheduler)
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f'Training time {total_time_str}')