torchbenchmark/models/Super_SloMo/train.py
#[Super SloMo]
##High Quality Estimation of Multiple Intermediate Frames for Video Interpolation
import argparse
import torch
import torchvision.transforms as transforms
import torch.optim as optim
import torch.nn as nn
from model_wrapper import Model
import dataloader
import time
from tensorboardX import SummaryWriter
import random
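# Fix all RNG seeds and force deterministic cuDNN kernels (no autotuning) so runs are reproducible.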
random.seed(1337)
torch.manual_seed(1337)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# For parsing commandline arguments
parser = argparse.ArgumentParser()
parser.add_argument("--dataset_root", type=str, required=True, help='path to dataset folder containing train-test-validation folders')
parser.add_argument("--checkpoint_dir", type=str, required=True, help='path to folder for saving checkpoints')
parser.add_argument("--checkpoint", type=str, help='path of checkpoint for pretrained model')
parser.add_argument("--epochs", type=int, default=200, help='number of epochs to train. Default: 200.')
parser.add_argument("--train_batch_size", type=int, default=6, help='batch size for training. Default: 6.')
parser.add_argument("--init_learning_rate", type=float, default=0.0001, help='set initial learning rate. Default: 0.0001.')
parser.add_argument("--milestones", type=int, nargs='+', default=[100, 150], help='epochs at which to decrease the learning rate by a factor of 0.1. Default: [100, 150]')
parser.add_argument("--checkpoint_epoch", type=int, default=5, help='checkpoint saving frequency. N: after every N epochs. Each checkpoint is roughly 151 MB in size. Default: 5.')
parser.add_argument("--debug", type=str, default=None, help='dump model output')
parser.add_argument("--trace", action='store_true', default=False, help='trace model')
parser.add_argument("--script", action='store_true', default=False, help='script model')
args = parser.parse_args()
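# Example invocation (the paths below are placeholders, not part of the repo):
#   python train.py --dataset_root /path/to/adobe240fps --checkpoint_dir ./checkpoints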
##[TensorboardX](https://github.com/lanpa/tensorboardX)
### For visualizing loss and interpolated frames
writer = SummaryWriter('log')
###Initialize flow computation and arbitrary-time flow interpolation CNNs.
# The benchmark requires a GPU, so fail fast if CUDA is unavailable.
assert torch.cuda.is_available()
device = torch.device("cuda:0")
###Load Datasets
# Channel wise mean calculated on adobe240-fps training dataset
mean = [0.429, 0.431, 0.397]
std = [1, 1, 1]
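# With std of 1, Normalize only subtracts the per-channel mean; the pixel scale is unchanged.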
normalize = transforms.Normalize(mean=mean, std=std)
transform = transforms.Compose([transforms.ToTensor(), normalize])
trainset = dataloader.SuperSloMo(root=args.dataset_root + '/train', transform=transform, train=True)
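# Shuffling is disabled so batches arrive in the same order on every run.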
trainloader = torch.utils.data.DataLoader(trainset, batch_size=args.train_batch_size, shuffle=False)
print(trainset)
###Create transform to display image from tensor
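# Normalizing with the negated mean (std still 1) inverts the mean subtraction above.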
negmean = [x * -1 for x in mean]
revNormalize = transforms.Normalize(mean=negmean, std=std)
TP = transforms.Compose([revNormalize, transforms.ToPILImage()])
###Utils
def get_lr(optimizer):
    # Return the learning rate of the first parameter group.
    return optimizer.param_groups[0]['lr']
###Model, Loss and Optimizer
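# Model (model_wrapper.py) bundles the flow-computation and arbitrary-time flow-interpolation
# CNNs and returns the interpolated frame together with the training loss.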
the_model = Model(device)
optimizer = optim.Adam(the_model.parameters(), lr=args.init_learning_rate)
# scheduler to decrease learning rate by a factor of 10 at milestones.
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.milestones, gamma=0.1)
### Initialization
dict1 = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}
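# 'epoch': -1 makes the training loop below start from epoch 0; the lists collect per-epoch stats.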
### Training
start = time.time()
cLoss = dict1['loss']
valLoss = dict1['valLoss']
valPSNR = dict1['valPSNR']
checkpoint_counter = 0
if args.trace:
    # torch.jit.trace needs concrete example inputs, so pull a single batch
    # from the training set, trace the model with it, and stop.
    for trainData, trainFrameIndex in trainloader:
        frame0, frameT, frame1 = trainData
        I0 = frame0.to(device)
        I1 = frame1.to(device)
        IFrame = frameT.to(device)
        the_model = torch.jit.trace(the_model, example_inputs=(trainFrameIndex, I0, I1, IFrame))
        break
if args.script:
    # Scripting compiles the model from source instead of tracing example inputs.
    the_model = torch.jit.script(the_model)
### Main training loop
for epoch in range(dict1['epoch'] + 1, args.epochs):
    print("Epoch: ", epoch)
    # Append fresh per-epoch lists and reset the running loss
    cLoss.append([])
    valLoss.append([])
    valPSNR.append([])
    iLoss = 0
    for trainIndex, (trainData, trainFrameIndex) in enumerate(trainloader, 0):
        ## Getting the input and the target from the training set
        frame0, frameT, frame1 = trainData
        I0 = frame0.to(device)
        I1 = frame1.to(device)
        IFrame = frameT.to(device)
        optimizer.zero_grad()
        # The wrapped model returns the interpolated frame and the combined loss
        Ft_p, loss = the_model(trainFrameIndex, I0, I1, IFrame)
        if args.debug:
            torch.save(Ft_p, args.debug)
        # Backpropagate
        loss.backward()
        optimizer.step()
        iLoss += loss.item()
    # Step the LR scheduler once per epoch, after the optimizer updates,
    # so the initial learning rate is not skipped (PyTorch >= 1.1 ordering).
    scheduler.step()