in pt/vmz/func/train.py [0:0]
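# NOTE: the import block is not part of this excerpt; the imports below are an
# assumption reconstructed from the names used in train_main(). Helpers such as
# utils, T (video transforms), models, get_dataset, setup_tbx, MetricLogger,
# get_default_loggers, WarmupMultiStepLR, train_one_epoch and evaluate are
# provided by the package's own modules, whose exact paths are not shown here.
import datetime
import os
import sys
import time

import torch
import torch.nn as nn
import torchvision
from torchvision.datasets.samplers import (
    DistributedSampler,
    RandomClipSampler,
    UniformClipSampler,
)

try:
    # optional dependency used for mixed-precision training (see args.apex below)
    from apex import amp
except ImportError:
    amp = None
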
def train_main(args):
    torchvision.set_video_backend("video_reader")

    if args.apex:
        if sys.version_info < (3, 0):
            raise RuntimeError("Apex currently only supports Python 3. Aborting.")
        if amp is None:
            raise RuntimeError(
                "Failed to import apex. Please install apex "
                "from https://www.github.com/nvidia/apex "
                "to enable mixed-precision training."
            )

    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)
    torch.backends.cudnn.benchmark = True

    writer = setup_tbx(args.output_dir)

    # Data loading code
    print("Loading data")
    print("\t Loading datasets")
    st = time.time()

    if not args.eval_only:
        print("\t Loading train data")
        transform_train = torchvision.transforms.Compose(
            [
                T.ToTensorVideo(),
                T.Resize((args.scale_h, args.scale_w)),
                T.RandomHorizontalFlipVideo(),
                T.NormalizeVideo(
                    mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)
                ),
                T.RandomCropVideo((args.crop_size, args.crop_size)),
            ]
        )
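        # NOTE (added): the mean/std values above are the Kinetics statistics used
        # to normalize torchvision's pretrained video models; scale and crop sizes
        # come from the command-line arguments.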
        dataset = get_dataset(args, transform_train)
        dataset.video_clips.compute_clips(args.num_frames, 1, frame_rate=15)
        train_sampler = RandomClipSampler(
            dataset.video_clips, args.train_bs_multiplier
        )
        if args.distributed:
            train_sampler = DistributedSampler(train_sampler)
        data_loader = torch.utils.data.DataLoader(
            dataset,
            batch_size=args.batch_size,
            sampler=train_sampler,
            num_workers=args.workers,
        )

    print("\t Loading validation data")
    transform_test = torchvision.transforms.Compose(
        [
            T.ToTensorVideo(),
            T.Resize((args.scale_h, args.scale_w)),
            T.NormalizeVideo(
                mean=(0.43216, 0.394666, 0.37645), std=(0.22803, 0.22145, 0.216989)
            ),
            T.CenterCropVideo((args.crop_size, args.crop_size)),
        ]
    )
    dataset_test = get_dataset(args, transform_test, split="val")
    dataset_test.video_clips.compute_clips(args.num_frames, 1, frame_rate=15)
    test_sampler = UniformClipSampler(
        dataset_test.video_clips, args.val_clips_per_video
    )
    if args.distributed:
        test_sampler = DistributedSampler(test_sampler)
    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
    )
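
    # NOTE (added): UniformClipSampler draws args.val_clips_per_video clips spaced
    # evenly across each validation video, so every video contributes the same
    # number of clips to the evaluation.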

    criterion = nn.CrossEntropyLoss()

    print("Creating model")
    # TODO: model only from our models
    available_models = {**models.__dict__}
    model = available_models[args.model](pretraining=args.pretrained)
    model.to(device)
    if args.distributed and args.sync_bn:
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)

    if args.resume_from_model and not args.resume:
        checkpoint = torch.load(args.resume_from_model, map_location="cpu")
        if "model" in checkpoint.keys():
            model.load_state_dict(checkpoint["model"])
        else:
            model.load_state_dict(checkpoint)

    if args.finetune:
        assert args.resume_from_model is not None or args.pretrained
        # replace the classification head; move it to the target device explicitly,
        # since model.to(device) was already called above
        model.fc = nn.Linear(model.fc.in_features, args.num_finetune_classes).to(device)

    lr = args.lr * args.world_size
    if args.finetune:
        params = [
            {"params": model.stem.parameters(), "lr": 0},
            {"params": model.layer1.parameters(), "lr": args.l1_lr * args.world_size},
            {"params": model.layer2.parameters(), "lr": args.l2_lr * args.world_size},
            {"params": model.layer3.parameters(), "lr": args.l3_lr * args.world_size},
            {"params": model.layer4.parameters(), "lr": args.l4_lr * args.world_size},
            {"params": model.fc.parameters(), "lr": args.fc_lr * args.world_size},
        ]
    else:
        params = model.parameters()
    print(params)
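    # NOTE (added): in finetune mode the stem is frozen (lr=0) and each residual
    # stage plus the new fc head gets its own learning rate, each scaled by
    # world_size; otherwise a single parameter group with the base lr is used.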

    optimizer = torch.optim.SGD(
        params, lr=lr, momentum=args.momentum, weight_decay=args.weight_decay
    )

    if args.apex:
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.apex_opt_level
        )

    # Convert the scheduler to step per iteration rather than per epoch,
    # so that warmup can span several epochs.
    if not args.eval_only:
        warmup_iters = args.lr_warmup_epochs * len(data_loader)
        lr_milestones = [len(data_loader) * m for m in args.lr_milestones]
        lr_scheduler = WarmupMultiStepLR(
            optimizer,
            milestones=lr_milestones,
            gamma=args.lr_gamma,
            warmup_iters=warmup_iters,
            warmup_factor=1e-5,
        )
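        # Example (hypothetical numbers): with 1,000 batches per epoch,
        # lr_warmup_epochs=10 and lr_milestones=[20, 30], warmup runs for
        # 10 * 1,000 = 10,000 iterations and the lr is multiplied by lr_gamma
        # at iterations 20,000 and 30,000.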

    if os.path.isfile(os.path.join(args.output_dir, "checkpoint.pth")):
        args.resume = os.path.join(args.output_dir, "checkpoint.pth")
    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        if not args.eval_only:
            # the scheduler only exists in training mode (see above)
            lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1

    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.gpu]
        )
        model_without_ddp = model.module
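    # NOTE (added): model_without_ddp keeps a handle on the underlying module so
    # that checkpoints store plain state_dict keys (without the "module." prefix)
    # even when the model is wrapped in DistributedDataParallel.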

    if args.eval_only:
        print("Starting test_only")
        metric_logger = MetricLogger(delimiter=" ", writer=writer, stat_set="val")
        evaluate(model, criterion, data_loader_test, device, metric_logger)
        return

    # Get training metric logger
    stat_loggers = get_default_loggers(writer, args.start_epoch)

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(
            model,
            criterion,
            optimizer,
            lr_scheduler,
            data_loader,
            device,
            epoch,
            args.print_freq,
            stat_loggers["train"],
            args.apex,
        )
        evaluate(model, criterion, data_loader_test, device, stat_loggers["val"])

        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, "model_{}.pth".format(epoch))
            )
            utils.save_on_master(
                checkpoint, os.path.join(args.output_dir, "checkpoint.pth")
            )

        # reset all meters in the metric logger
        for log in stat_loggers:
            stat_loggers[log].reset_meters()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print("Training time {}".format(total_time_str))