in references/video_classification/train.py [0:0]
def main(args):
    if args.prototype and prototype is None:
        raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
    if not args.prototype and args.weights:
        raise ValueError("The weights parameter works only in prototype mode. Please pass the --prototype argument.")
    if args.output_dir:
        utils.mkdir(args.output_dir)

    utils.init_distributed_mode(args)
    print(args)
    print("torch version: ", torch.__version__)
    print("torchvision version: ", torchvision.__version__)

    device = torch.device(args.device)

    torch.backends.cudnn.benchmark = True
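
    # cudnn.benchmark enables cuDNN's autotuner, which times candidate
    # convolution kernels for the input shapes it encounters; since every clip
    # is cropped to a fixed size below, the tuning cost is paid once up front
    # and subsequent iterations reuse the fastest kernels.
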
    # Data loading code
    print("Loading data")
    traindir = os.path.join(args.data_path, args.train_dir)
    valdir = os.path.join(args.data_path, args.val_dir)

    print("Loading training data")
    st = time.time()
    cache_path = _get_cache_path(traindir)
    transform_train = presets.VideoClassificationPresetTrain((128, 171), (112, 112))

    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_train from {cache_path}")
        dataset, _ = torch.load(cache_path)
        dataset.transform = transform_train
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset = torchvision.datasets.Kinetics400(
            traindir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_train,
            frame_rate=15,
            extensions=(
                "avi",
                "mp4",
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_train to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset, traindir), cache_path)

    print("Took", time.time() - st)
print("Loading validation data")
cache_path = _get_cache_path(valdir)
if not args.prototype:
transform_test = presets.VideoClassificationPresetEval(resize_size=(128, 171), crop_size=(112, 112))
else:
if args.weights:
weights = prototype.models.get_weight(args.weights)
transform_test = weights.transforms()
else:
transform_test = prototype.transforms.Kinect400Eval(crop_size=(112, 112), resize_size=(128, 171))
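
    # In prototype mode with --weights, the eval preprocessing comes from the
    # weights entry itself (weights.transforms()), so inference uses exactly
    # the transforms the checkpoint was validated with rather than hard-coded
    # sizes.
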
    if args.cache_dataset and os.path.exists(cache_path):
        print(f"Loading dataset_test from {cache_path}")
        dataset_test, _ = torch.load(cache_path)
        dataset_test.transform = transform_test
    else:
        if args.distributed:
            print("It is recommended to pre-compute the dataset cache on a single-gpu first, as it will be faster")
        dataset_test = torchvision.datasets.Kinetics400(
            valdir,
            frames_per_clip=args.clip_len,
            step_between_clips=1,
            transform=transform_test,
            frame_rate=15,
            extensions=(
                "avi",
                "mp4",
            ),
        )
        if args.cache_dataset:
            print(f"Saving dataset_test to {cache_path}")
            utils.mkdir(os.path.dirname(cache_path))
            utils.save_on_master((dataset_test, valdir), cache_path)
print("Creating data loaders")
train_sampler = RandomClipSampler(dataset.video_clips, args.clips_per_video)
test_sampler = UniformClipSampler(dataset_test.video_clips, args.clips_per_video)
if args.distributed:
train_sampler = DistributedSampler(train_sampler)
test_sampler = DistributedSampler(test_sampler)
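
    # RandomClipSampler draws at most args.clips_per_video random clips from
    # each video for training; UniformClipSampler picks them evenly spaced so
    # evaluation covers each video deterministically. Note that the samplers,
    # not the datasets, are wrapped here: the DistributedSampler in this
    # reference script (presumably torchvision.datasets.samplers.
    # DistributedSampler) can wrap another sampler and shard the clip indices
    # it yields across processes, unlike torch.utils.data.DistributedSampler.
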
    data_loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        sampler=train_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )

    data_loader_test = torch.utils.data.DataLoader(
        dataset_test,
        batch_size=args.batch_size,
        sampler=test_sampler,
        num_workers=args.workers,
        pin_memory=True,
        collate_fn=collate_fn,
    )
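
    # collate_fn is a module-level helper defined elsewhere in this script
    # (not shown in this excerpt); in the torchvision reference it drops the
    # audio tensor from each (video, audio, label) sample before default
    # collation, since the models only consume the video stream.
    # pin_memory=True keeps batches in page-locked host memory so the copy to
    # GPU is faster.
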
print("Creating model")
if not args.prototype:
model = torchvision.models.video.__dict__[args.model](pretrained=args.pretrained)
else:
model = prototype.models.video.__dict__[args.model](weights=args.weights)
model.to(device)
if args.distributed and args.sync_bn:
model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
criterion = nn.CrossEntropyLoss()
lr = args.lr * args.world_size
optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=args.momentum, weight_decay=args.weight_decay)
scaler = torch.cuda.amp.GradScaler() if args.amp else None
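
    # The learning rate is scaled linearly with world_size above because each
    # DDP process draws its own batch_size samples, so the effective global
    # batch grows with the number of processes (the linear scaling rule).
    # GradScaler, used only under --amp, scales the loss before backward so
    # small fp16 gradients don't underflow, then unscales before the step.
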
    # Convert the scheduler to be stepped per iteration rather than per epoch,
    # so that warmup can span multiple epochs.
    iters_per_epoch = len(data_loader)
    lr_milestones = [iters_per_epoch * (m - args.lr_warmup_epochs) for m in args.lr_milestones]
    main_lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=lr_milestones, gamma=args.lr_gamma)

    if args.lr_warmup_epochs > 0:
        warmup_iters = iters_per_epoch * args.lr_warmup_epochs
        args.lr_warmup_method = args.lr_warmup_method.lower()
        if args.lr_warmup_method == "linear":
            warmup_lr_scheduler = torch.optim.lr_scheduler.LinearLR(
                optimizer, start_factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        elif args.lr_warmup_method == "constant":
            warmup_lr_scheduler = torch.optim.lr_scheduler.ConstantLR(
                optimizer, factor=args.lr_warmup_decay, total_iters=warmup_iters
            )
        else:
            raise RuntimeError(
                f"Invalid warmup lr method '{args.lr_warmup_method}'. Only linear and constant are supported."
            )
        lr_scheduler = torch.optim.lr_scheduler.SequentialLR(
            optimizer, schedulers=[warmup_lr_scheduler, main_lr_scheduler], milestones=[warmup_iters]
        )
    else:
        lr_scheduler = main_lr_scheduler
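
    # SequentialLR drives the warmup scheduler for the first warmup_iters
    # steps, then hands over to MultiStepLR. Milestones are converted from
    # epochs to iterations, minus the warmup epochs, because MultiStepLR only
    # starts counting after the handover. Illustrative numbers (not from this
    # file): with iters_per_epoch=100, lr_warmup_epochs=10 and
    # lr_milestones=[20, 30], the converted milestones are [1000, 2000],
    # warmup covers iterations 0-999, and the decays land at overall
    # iterations 2000 and 3000, i.e. at epochs 20 and 30 as requested.
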
    model_without_ddp = model
    if args.distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
        model_without_ddp = model.module

    if args.resume:
        checkpoint = torch.load(args.resume, map_location="cpu")
        model_without_ddp.load_state_dict(checkpoint["model"])
        optimizer.load_state_dict(checkpoint["optimizer"])
        lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        args.start_epoch = checkpoint["epoch"] + 1
        if args.amp:
            scaler.load_state_dict(checkpoint["scaler"])
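
    # Resuming restores the full training state: model weights (loaded into
    # the module underneath the DDP wrapper), optimizer moments, scheduler
    # position and, under --amp, the GradScaler's running scale, then restarts
    # at the epoch after the one stored in the checkpoint.
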
    if args.test_only:
        evaluate(model, criterion, data_loader_test, device=device)
        return

    print("Start training")
    start_time = time.time()
    for epoch in range(args.start_epoch, args.epochs):
        if args.distributed:
            train_sampler.set_epoch(epoch)
        train_one_epoch(model, criterion, optimizer, lr_scheduler, data_loader, device, epoch, args.print_freq, scaler)
        evaluate(model, criterion, data_loader_test, device=device)
        if args.output_dir:
            checkpoint = {
                "model": model_without_ddp.state_dict(),
                "optimizer": optimizer.state_dict(),
                "lr_scheduler": lr_scheduler.state_dict(),
                "epoch": epoch,
                "args": args,
            }
            if args.amp:
                checkpoint["scaler"] = scaler.state_dict()
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
            utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    print(f"Training time {total_time_str}")