# main_stica.py
def main():
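    """Run distributed self-supervised audio-video pretraining with STiCA."""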
# parse arguments
global args
parser = parse_arguments()
args = parser.parse_args()
# exp setup: logger, distributed mode and seeds
init_distributed_mode(args)
init_signal_handler()
fix_random_seeds(args.seed)
logger, training_stats = initialize_exp(args, "epoch", "loss")
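    # only the main process writes TensorBoard summaries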
if args.rank == 0:
writer = SummaryWriter(args.dump_path)
writer.add_text(
'args',
" \n".join(['%s : %s' % (arg, getattr(args, arg)) for arg in vars(args)]),
0
)
else:
writer = None
    # SpecAugment parameters for the audio stream; an empty list disables augmentation
if args.audio_augtype == 'mild':
aug_audio = [1, 1, 2, 5]
elif args.audio_augtype == 'medium':
aug_audio = [1, 1, 3, 6]
elif args.audio_augtype == 'heavy':
aug_audio = [2, 2, 3, 6]
else:
aug_audio = []
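    # build dataset and distributed data loader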
train_dataset = AVideoDataset(
ds_name=args.dataset_name,
mode='train',
root_dir=args.root_dir,
decode_audio=True,
args=args
)
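    # shard the dataset so each process/GPU sees a distinct subset each epoch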
sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
train_loader = torch.utils.data.DataLoader(
train_dataset,
sampler=sampler,
batch_size=args.batch_size,
num_workers=args.workers,
pin_memory=True,
drop_last=True
)
logger.info("Building data done with {} images loaded.".format(
len(train_dataset)))
# build model
model = Stica_TransformerFMCrop(
vid_base_arch='r2plus1d_18',
aud_base_arch='resnet9',
pretrained=False,
norm_feat=True,
use_mlp=True,
num_classes=256, # embedding dimension
args=args
)
# synchronize batch norm layers
model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
# copy model to GPU
model = model.cuda()
if args.rank == 0:
logger.info(model)
logger.info("Building model done.")
# build optimizer
optimizer = torch.optim.SGD(
model.parameters(),
lr=args.base_lr,
momentum=0.9,
weight_decay=args.wd,
)
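    # build the per-iteration learning-rate schedule: optional linear warmup,
    # then cosine decay or a constant rate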
    if args.use_warmup_scheduler:
        warmup_lr_schedule = np.linspace(
            args.start_warmup, args.base_lr, len(train_loader) * args.warmup_epochs)
    else:
        # no warmup: start directly at the base learning rate
        warmup_lr_schedule = np.array([])
    iters = np.arange(len(train_loader) * (args.epochs - args.warmup_epochs))
    if args.use_lr_scheduler:
        # cosine decay from base_lr to final_lr over the post-warmup iterations
        cosine_lr_schedule = np.array(
            [args.final_lr + 0.5 * (args.base_lr - args.final_lr) * (1 +
                math.cos(math.pi * t / (len(train_loader) * (args.epochs - args.warmup_epochs))))
             for t in iters
             ])
        lr_schedule = np.concatenate((warmup_lr_schedule, cosine_lr_schedule))
    else:
        constant_schedule = np.array([args.base_lr for t in iters])
        lr_schedule = np.concatenate((warmup_lr_schedule, constant_schedule))
logger.info("Building optimizer done.")
# init mixed precision
if args.use_fp16:
model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")
logger.info("Initializing mixed precision done.")
# wrap model
model = nn.parallel.DistributedDataParallel(
model,
device_ids=[args.gpu_to_work_on],
find_unused_parameters=True,
)
# optionally resume from a checkpoint
to_restore = {"epoch": 0}
restart_from_checkpoint(
os.path.join(args.dump_path, "checkpoint.pth.tar"),
run_variables=to_restore,
state_dict=model,
optimizer=optimizer,
amp=apex.amp if args.use_fp16 else None,
)
start_epoch = to_restore["epoch"]
    # enable cuDNN benchmark mode to auto-tune convolution kernels
cudnn.benchmark = True
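    # train for args.epochs epochs, checkpointing on the main process after each epoch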
for epoch in range(start_epoch, args.epochs):
# train the network for one epoch
logger.info("============ Starting epoch %i ... ============" % epoch)
# set sampler
train_loader.sampler.set_epoch(epoch)
# train the network
scores = train(
train_loader, model, optimizer, epoch, lr_schedule, writer)
training_stats.update(scores)
if args.rank == 0 and writer:
writer.add_scalar('pretrain/epoch', epoch, epoch)
# save checkpoints
if args.rank == 0:
save_dict = {
"epoch": epoch + 1,
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
}
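            # include the apex AMP state so mixed-precision training resumes cleanly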
if args.use_fp16:
save_dict["amp"] = apex.amp.state_dict()
torch.save(
save_dict,
os.path.join(
args.dump_path,
"checkpoint.pth.tar"
),
)
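            # keep a periodic, permanent copy of the checkpoint in dump_checkpoints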
if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
shutil.copyfile(
os.path.join(
args.dump_path,
"checkpoint.pth.tar"
),
os.path.join(
args.dump_checkpoints,
"ckp-" + str(epoch) + ".pth"
),
)