main.py
def main():
    # parse arguments
    global args
    parser = parse_arguments()
    args = parser.parse_args()
    # exp setup: logger, distributed mode and seeds
    init_distributed_mode(args)
    init_signal_handler()
    fix_random_seeds(args.seed)
    logger, training_stats = initialize_exp(args, "epoch", "loss")
    if args.rank == 0:
        writer = SummaryWriter(args.dump_path)
    else:
        writer = None
    # build data
    train_dataset = AVideoDataset(
        ds_name=args.ds_name,
        root_dir=args.root_dir,
        mode='train',
        path_to_data_dir=args.data_path,
        num_frames=args.num_frames,
        target_fps=args.target_fps,
        sample_rate=args.sample_rate,
        num_train_clips=args.num_train_clips,
        train_crop_size=args.train_crop_size,
        test_crop_size=args.test_crop_size,
        num_data_samples=args.num_data_samples,
        colorjitter=args.colorjitter,
        use_grayscale=args.use_grayscale,
        use_gaussian=args.use_gaussian,
        temp_jitter=True,
        decode_audio=True,
        aug_audio=None,
        num_sec=args.num_sec_aud,
        aud_sample_rate=args.aud_sample_rate,
        aud_spec_type=args.aud_spec_type,
        use_volume_jittering=args.use_volume_jittering,
        use_temporal_jittering=args.use_audio_temp_jittering,
        z_normalize=args.z_normalize,
        dual_data=args.dual_data
    )
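    # DistributedSampler gives each process a distinct shard of the dataset,
    # reshuffled every epoch via set_epoch; drop_last discards the final
    # incomplete batch so every step uses a full batch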
    sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        sampler=sampler,
        batch_size=args.batch_size,
        num_workers=args.workers,
        pin_memory=True,
        drop_last=True
    )
    logger.info("Loaded data with {} videos.".format(len(train_dataset)))
    # Load model
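    # load_model (project helper) is assumed to build the video and audio backbones
    # with optional MLP projection heads; headcount is the number of clustering heads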
    model = load_model(
        vid_base_arch=args.vid_base_arch,
        aud_base_arch=args.aud_base_arch,
        use_mlp=args.use_mlp,
        num_classes=args.mlp_dim,
        pretrained=False,
        norm_feat=False,
        use_max_pool=False,
        headcount=args.headcount,
    )
    # synchronize batch norm layers
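    # (BN statistics are shared across GPUs; with apex, syncing is done within process
    # groups of size world_size // 8 once at least 8 processes are used, which limits
    # the communication compared to a global sync)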
    if args.sync_bn == "pytorch":
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
    elif args.sync_bn == "apex":
        process_group = None
        if args.world_size // 8 > 0:
            process_group = apex.parallel.create_syncbn_process_group(args.world_size // 8)
        model = apex.parallel.convert_syncbn_model(model, process_group=process_group)
    # copy model to GPU
    model = model.cuda()
    if args.rank == 0:
        logger.info(model)
    logger.info("Building model done.")
    # build optimizer
    optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.base_lr,
        momentum=0.9,
        weight_decay=args.wd,
    )
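    # optional linear LR warmup: GradualWarmupScheduler is assumed to ramp the learning
    # rate from base_lr up to base_lr * world_size over the first warmup_epochs epochs
    # (the usual linear-scaling rule for large distributed batches)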
    if args.use_warmup_scheduler:
        lr_scheduler = GradualWarmupScheduler(
            optimizer,
            multiplier=args.world_size,
            total_epoch=args.warmup_epochs,
            after_scheduler=None
        )
    else:
        lr_scheduler = None
    logger.info("Building optimizer done.")
    # init mixed precision
    if args.use_fp16:
        model, optimizer = apex.amp.initialize(model, optimizer, opt_level="O1")
        logger.info("Initializing mixed precision done.")
    # wrap model
    model = nn.parallel.DistributedDataParallel(
        model,
        device_ids=[args.gpu_to_work_on],
        find_unused_parameters=True,
    )
    # SK-Init
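    # `selflabels` stores one pseudo-label per video and per clustering head;
    # `sk_schedule` lists the minibatch indices (counted over the whole run) at which
    # the labels are re-optimised, presumably with the Sinkhorn-Knopp procedure used
    # in train(); with schedulepower > 1 the optimisations are denser early in training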
    N_dl = len(train_loader)
    N = len(train_loader.dataset)
    N_distr = N_dl * train_loader.batch_size
    selflabels = torch.zeros((N, args.headcount), dtype=torch.long, device='cuda')
    global sk_schedule
    sk_schedule = (args.epochs * N_dl * (np.linspace(0, 1, args.nopts) ** args.schedulepower)[::-1]).tolist()
    # prepend a sentinel past the end of training so the schedule never runs empty
    sk_schedule = [(args.epochs + 2) * N_dl] + sk_schedule
    logger.info(f'remaining SK opts @ epochs {[np.round(1.0 * t / N_dl, 2) for t in sk_schedule]}')
    # optionally resume from a checkpoint
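    # restart_from_checkpoint is expected to overwrite the entries of to_restore in
    # place when a checkpoint file exists; otherwise the defaults set here are kept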
    to_restore = {"epoch": 0, 'selflabels': selflabels, 'dist': args.dist}
    restart_from_checkpoint(
        os.path.join(args.dump_path, "checkpoint.pth.tar"),
        run_variables=to_restore,
        model=model,
        optimizer=optimizer,
        amp=apex.amp if args.use_fp16 else None,
    )
    start_epoch = to_restore["epoch"]
    selflabels = to_restore["selflabels"]
    args.dist = to_restore["dist"]
    # Set CuDNN benchmark
    cudnn.benchmark = True
    # Restart schedule correctly
    if start_epoch != 0:
        include = [(qq / N_dl > start_epoch) for qq in sk_schedule]
        # (total number of sk-opts) - (number of sk-opts outstanding)
        global sk_counter
        sk_counter = len(sk_schedule) - sum(include)
        sk_schedule = (np.array(sk_schedule)[include]).tolist()
        if lr_scheduler:
            for _ in range(to_restore['epoch']):
                lr_scheduler.step()
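    # when training from scratch, run a few batches through the model so the (sync) BN
    # running statistics are warmed up before the first real epoch; `group` is assumed
    # to be the sync-BN process group created above (or None)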
    if start_epoch == 0:
        train_loader.sampler.set_epoch(999)
        warmup_batchnorm(args, model, train_loader, batches=20, group=group)
    for epoch in range(start_epoch, args.epochs):
        # train the network for one epoch
        logger.info("============ Starting epoch %i ... ============" % epoch)
        if writer:
            writer.add_scalar('train/epoch', epoch, epoch)
        # set sampler
        train_loader.sampler.set_epoch(epoch)
        # train the network
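        # train() is expected to return the per-epoch stats consumed by training_stats
        # and the self-labels, re-optimised whenever the SK schedule fired this epoch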
        scores, selflabels = train(
            train_loader, model, optimizer, epoch, writer, selflabels)
        training_stats.update(scores)
        # Update LR scheduler
        if lr_scheduler:
            lr_scheduler.step()
        # save checkpoints
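        # only the main process writes checkpoints; the amp state is included when fp16
        # is used, and a separate copy is kept every checkpoint_freq epochs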
        if args.rank == 0:
            save_dict = {
                "epoch": epoch + 1,
                "dist": args.dist,
                "model": model.state_dict(),
                "optimizer": optimizer.state_dict(),
                "selflabels": selflabels
            }
            if args.use_fp16:
                save_dict["amp"] = apex.amp.state_dict()
            torch.save(
                save_dict,
                os.path.join(args.dump_path, "checkpoint.pth.tar"),
            )
            if epoch % args.checkpoint_freq == 0 or epoch == args.epochs - 1:
                shutil.copyfile(
                    os.path.join(args.dump_path, "checkpoint.pth.tar"),
                    os.path.join(args.dump_checkpoints, "ckp-" + str(epoch) + ".pth")
                )