# torchbenchmark/models/Super_SloMo/train.py

# [Super SloMo]
## High Quality Estimation of Multiple Intermediate Frames for Video Interpolation

import argparse
import torch
import torchvision
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
import torch.nn.functional as F
import slomo_model as model
from model_wrapper import Model
import dataloader
from math import log10
import datetime
from tensorboardX import SummaryWriter
import random

# Seed RNGs and force deterministic cuDNN kernels for reproducible runs.
random.seed(1337)
torch.manual_seed(1337)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# For parsing commandline arguments
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_root", type=str, required=True,
                    help='path to dataset folder containing train-test-validation folders')
parser.add_argument("--checkpoint_dir", type=str, required=True,
                    help='path to folder for saving checkpoints')
parser.add_argument("--checkpoint", type=str,
                    help='path of checkpoint for pretrained model')
parser.add_argument("--epochs", type=int, default=200,
                    help='number of epochs to train. Default: 200.')
parser.add_argument("--train_batch_size", type=int, default=6,
                    help='batch size for training. Default: 6.')
parser.add_argument("--init_learning_rate", type=float, default=0.0001,
                    help='set initial learning rate. Default: 0.0001.')
# Note: `type=list` would split a command-line string like "100" into
# individual characters; parse milestones as a list of ints instead.
parser.add_argument("--milestones", type=int, nargs='+', default=[100, 150],
                    help='epoch values at which to decrease learning rate by a factor of 0.1. Default: [100, 150]')
parser.add_argument("--checkpoint_epoch", type=int, default=5,
                    help='checkpoint saving frequency. N: after every N epochs. Each checkpoint is roughly 151 MB in size. Default: 5.')
parser.add_argument("--debug", type=str, default=None,
                    help='dump model output')
parser.add_argument("--trace", action='store_true', default=False,
                    help='trace model')
parser.add_argument("--script", action='store_true', default=False,
                    help='script model')
args = parser.parse_args()

## [TensorboardX](https://github.com/lanpa/tensorboardX)
### For visualizing loss and interpolated frames
writer = SummaryWriter('log')

### Initialize flow computation and arbitrary-time flow interpolation CNNs.
assert torch.cuda.is_available()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

### Load Datasets
# Channel-wise mean calculated on the adobe240-fps training dataset.
mean = [0.429, 0.431, 0.397]
std = [1, 1, 1]
normalize = transforms.Normalize(mean=mean, std=std)
transform = transforms.Compose([transforms.ToTensor(), normalize])

trainset = dataloader.SuperSloMo(root=args.dataset_root + '/train', transform=transform, train=True)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=False)
print(trainset)

### Create transform to display image from tensor
negmean = [x * -1 for x in mean]
revNormalize = transforms.Normalize(mean=negmean, std=std)
TP = transforms.Compose([revNormalize, transforms.ToPILImage()])

### Utils
def get_lr(optimizer):
    # Return the learning rate of the first parameter group.
    for param_group in optimizer.param_groups:
        return param_group['lr']

### Model, Loss and Optimizer
the_model = Model(device)
optimizer = optim.Adam(the_model.parameters(), lr=args.init_learning_rate)
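# A minimal sketch (not part of the original script) of what the MultiStepLR
# schedule constructed below does: each milestone epoch multiplies the
# learning rate by gamma = 0.1. The throwaway names `_opt` and `_sched` are
# illustrative only.
#
#   _opt = optim.Adam([torch.zeros(1, requires_grad=True)], lr=1e-4)
#   _sched = optim.lr_scheduler.MultiStepLR(_opt, milestones=[100, 150], gamma=0.1)
#   for _ in range(200):
#       _opt.step()
#       _sched.step()
#   # get_lr(_opt) now reports 1e-6: 1e-4 decayed by 0.1 at steps 100 and 150.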
# Scheduler to decrease learning rate by a factor of 10 at the milestones.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=0.1)

### Initialization
dict1 = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}

### Training
import time
start = time.time()

cLoss = dict1['loss']
valLoss = dict1['valLoss']
valPSNR = dict1['valPSNR']
checkpoint_counter = 0

if args.trace:
    # Trace the model once with an example batch from the training set.
    for trainData, trainFrameIndex in trainloader:
        frame0, frameT, frame1 = trainData
        I0 = frame0.to(device)
        I1 = frame1.to(device)
        IFrame = frameT.to(device)
        the_model = torch.jit.trace(the_model, example_inputs=(trainFrameIndex, I0, I1, IFrame))
        break

if args.script:
    the_model = torch.jit.script(the_model)

### Main training loop
for epoch in range(dict1['epoch'] + 1, args.epochs):
    print("Epoch: ", epoch)

    # Append and reset
    cLoss.append([])
    valLoss.append([])
    valPSNR.append([])
    iLoss = 0

    # Increment scheduler count. (Kept at the start of the epoch as in the
    # original; note that PyTorch >= 1.1 recommends calling scheduler.step()
    # after the epoch's optimizer steps.)
    scheduler.step()

    for trainIndex, (trainData, trainFrameIndex) in enumerate(trainloader, 0):

        ## Getting the input and the target from the training set
        frame0, frameT, frame1 = trainData

        I0 = frame0.to(device)
        I1 = frame1.to(device)
        IFrame = frameT.to(device)

        optimizer.zero_grad()

        Ft_p, loss = the_model(trainFrameIndex, I0, I1, IFrame)

        # Optionally dump the model output for debugging.
        if args.debug:
            torch.save(Ft_p, args.debug)

        # Backpropagate
        loss.backward()
        optimizer.step()

        iLoss += loss.item()
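# Hedged sketch (an assumption; the remainder of the file is not shown in this
# excerpt): the `log10` import above is typically used to report PSNR from a
# mean-squared error over frames scaled to [0, 1], which is the kind of value
# the `valPSNR` list is set up to collect. The names `_mse` and `_psnr` below
# are illustrative only.
#
#   _mse = F.mse_loss(Ft_p, IFrame)
#   _psnr = 10 * log10(1.0 / _mse.item())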