torchbenchmark/models/Super_SloMo/train.ipynb (689 lines of code) (raw):
{
"cells": [
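{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "f_KNv25DX7B6"
},
"source": [
"# [Super SloMo](https://people.cs.umass.edu/~hzjiang/projects/superslomo/)\n",
"## High Quality Estimation of Multiple Intermediate Frames for Video Interpolation\n",
"\n",
"This notebook jointly trains the flow-computation and arbitrary-time flow-interpolation networks, logs losses and validation PSNR to TensorBoard, and saves periodic checkpoints.\n"
]
},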
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "0VWuBGh6zMMZ"
},
"outputs": [],
"source": [
"%matplotlib inline\n",
"import torch\n",
"import torchvision\n",
"import torchvision.transforms as transforms\n",
"import torch.optim as optim\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import slomo_model as model\n",
"import dataloader\n",
"import matplotlib.pyplot as plt\n",
"from math import log10\n",
"from IPython.display import clear_output, display\n",
"import datetime\n",
"from tensorboardX import SummaryWriter"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "1VynXmoKp_3M"
},
"source": [
"## Parameters"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "N2yrOVZjqDe9"
},
"outputs": [],
"source": [
"# Learning rate. The rate is multiplied by 0.1 each time the scheduler's\n",
"# step count reaches a value listed in `MILESTONES` (one step per epoch).\n",
"INITIAL_LEARNING_RATE = 0.0001\n",
"MILESTONES = [100, 150]\n",
"\n",
"# Number of epochs to train\n",
"EPOCHS = 200\n",
"\n",
"# Choose batch sizes according to the available GPU/CPU memory.\n",
"# This configuration works on a GTX 1080 Ti.\n",
"TRAIN_BATCH_SIZE = 6\n",
"VALIDATION_BATCH_SIZE = 10\n",
"\n",
"# Path to the dataset folder. It must contain `train` and `validation`\n",
"# subfolders (a `test` folder, if present, is not used by this notebook).\n",
"DATASET_ROOT = \"path/to/dataset\"\n",
"\n",
"# Path to the folder where checkpoints are saved\n",
"CHECKPOINT_DIR = 'path/to/checkpoint_directory'\n",
"\n",
"# To resume from a checkpoint, set `TRAINING_CONTINUE` to True and point\n",
"# `CHECKPOINT_PATH` at the checkpoint file.\n",
"TRAINING_CONTINUE = False\n",
"CHECKPOINT_PATH = 'path/to/checkpoint/file'\n",
"\n",
"# Progress-report and validation frequency (N: after every N training iterations)\n",
"PROGRESS_ITER = 100\n",
"\n",
"# Checkpoint frequency (N: after every N epochs). The network weights alone take\n",
"# roughly 151 MB per checkpoint.\n",
"CHECKPOINT_EPOCH = 5"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Yr3Lm1ovbWv1"
},
"source": [
"## [TensorboardX](https://github.com/lanpa/tensorboardX)\n",
"### For visualizing losses and interpolated frames"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "saUJTMiMCAzH"
},
"outputs": [],
"source": [
"# Event files are written to the `log` directory (relative to the working\n",
"# directory); view them with `tensorboard --logdir log`.\n",
"writer = SummaryWriter(\"log\")"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Ua1DJm82aj5-"
},
"source": [
"### Initialize the flow computation and arbitrary-time flow interpolation CNNs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "D42vzEKrWtpG"
},
"outputs": [],
"source": [
"# Train on the first GPU when one is available, otherwise on the CPU.\n",
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"\n",
"# Flow computation network: the two input frames stacked on the channel axis\n",
"# (3 + 3 = 6 channels) -> flows F_0_1 and F_1_0 (2 + 2 = 4 channels).\n",
"flowComp = model.UNet(6, 4).to(device)\n",
"\n",
"# Arbitrary-time flow interpolation network:\n",
"#   in  (20 ch) = I0, I1 (3 each), F_0_1, F_1_0, F_t_1, F_t_0 (2 each),\n",
"#                 g(I1, F_t_1), g(I0, F_t_0) (3 each)\n",
"#   out  (5 ch) = residuals for F_t_0 and F_t_1 (2 each) + visibility map V_t_0 (1)\n",
"ArbTimeFlowIntrp = model.UNet(20, 5)\n",
"ArbTimeFlowIntrp.to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "UYMpk2EYchaY"
},
"source": [
"### Initialize backward warpers for the train and validation datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "vJq6SrWIf2GE"
},
"outputs": [],
"source": [
"# Backward-warping modules, one per frame size: presumably 352x352 training\n",
"# crops (the dataloader's default crop - confirm) and 640x352 validation crops\n",
"# (matching randomCropSize=(640, 352) used for the validation set).\n",
"# NOTE(review): argument order presumably (width, height, device) - confirm in slomo_model.\n",
"trainFlowBackWarp = model.backWarp(352, 352, device).to(device)\n",
"validationFlowBackWarp = model.backWarp(640, 352, device).to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "oSs9UaIjdTT2"
},
"source": [
"### Load datasets"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "MJ9cVigEgtyT"
},
"outputs": [],
"source": [
"# Channel-wise mean computed on the Adobe240-fps training set. With std = 1\n",
"# the normalization only subtracts the mean.\n",
"mean = [0.429, 0.431, 0.397]\n",
"std = [1, 1, 1]\n",
"normalize = transforms.Normalize(mean=mean, std=std)\n",
"transform = transforms.Compose([transforms.ToTensor(), normalize])\n",
"\n",
"trainset = dataloader.SuperSloMo(\n",
"    root=DATASET_ROOT + '/train',\n",
"    transform=transform,\n",
"    train=True,\n",
")\n",
"trainloader = torch.utils.data.DataLoader(trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)\n",
"\n",
"# Validation samples are cropped to 640x352 (matching validationFlowBackWarp) and not shuffled.\n",
"validationset = dataloader.SuperSloMo(\n",
"    root=DATASET_ROOT + '/validation',\n",
"    transform=transform,\n",
"    randomCropSize=(640, 352),\n",
"    train=False,\n",
")\n",
"validationloader = torch.utils.data.DataLoader(validationset, batch_size=VALIDATION_BATCH_SIZE, shuffle=False)\n",
"\n",
"print(trainset, validationset)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "WXmNMdbJfp2d"
},
"source": [
"### Create a transform to display an image from a tensor"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "try3adPHgwse"
},
"outputs": [],
"source": [
"# Inverse of `normalize`, used to turn network tensors back into viewable images.\n",
"# Normalize computes (x - m) / s, so its inverse is Normalize(mean=-m/s, std=1/s).\n",
"# With std = [1, 1, 1] this just adds the mean back, but the general form stays\n",
"# correct if `std` is ever changed.\n",
"negmean = [-m / s for m, s in zip(mean, std)]\n",
"revStd = [1 / s for s in std]\n",
"revNormalize = transforms.Normalize(mean=negmean, std=revStd)\n",
"TP = transforms.Compose([revNormalize, transforms.ToPILImage()])"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "32XZg9Mfd5bN"
},
"source": [
"### Test the dataset"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "0Vyf7dbwCO1E"
},
"outputs": [],
"source": [
"# Display frame0, the target intermediate frame and frame1 of the first training sample.\n",
"trainData, frameIndex = next(iter(trainloader))\n",
"frame0, frameT, frame1 = trainData\n",
"print(\"Intermediate frame index: \", (frameIndex[0]))\n",
"for position, frame in enumerate((frame0, frameT, frame1)):\n",
"    if position > 0:\n",
"        plt.figure()\n",
"    plt.imshow(TP(frame[0]))\n",
"    plt.grid(True)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "rh0MK2qKuBlV"
},
"source": [
"### Utils"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "BdMFU0ijfIuI"
},
"outputs": [],
"source": [
"plt.rcParams['figure.figsize'] = [15, 3]\n",
"\n",
"def Plot(num, listInp, d):\n",
"    \"\"\"Plot the per-epoch mean of a nested metric history.\n",
"\n",
"    Args:\n",
"        num: subplot position in a 1x2 grid (1 = left, 2 = right).\n",
"        listInp: list holding one list of recorded values per epoch (e.g. `cLoss`).\n",
"        d: matplotlib color of the line.\n",
"\n",
"    An epoch with no recorded values (e.g. one with fewer than `PROGRESS_ITER`\n",
"    iterations) is plotted as a NaN gap instead of raising ZeroDivisionError.\n",
"    \"\"\"\n",
"    c = [sum(b) / len(b) if b else float('nan') for b in listInp]\n",
"    plt.subplot(1, 2, num)\n",
"    plt.plot(c, color=d)\n",
"    plt.grid(True)\n",
"\n",
"def get_lr(optimizer):\n",
"    \"\"\"Return the learning rate of the optimizer's first param group (None if it has none).\"\"\"\n",
"    for param_group in optimizer.param_groups:\n",
"        return param_group['lr']"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "mooLcmxtpPR_"
},
"source": [
"### Loss and optimizer"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "BuWQfcb-jhWx"
},
"outputs": [],
"source": [
"# Losses shared by training and validation\n",
"L1_lossFn = nn.L1Loss()\n",
"MSE_LossFn = nn.MSELoss()\n",
"\n",
"# One Adam optimizer trains both networks jointly.\n",
"params = [*ArbTimeFlowIntrp.parameters(), *flowComp.parameters()]\n",
"optimizer = optim.Adam(params, lr=INITIAL_LEARNING_RATE)\n",
"\n",
"# Multiply the learning rate by 0.1 whenever the scheduler's step count\n",
"# reaches a value in MILESTONES (the training loop steps it once per epoch).\n",
"scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=MILESTONES, gamma=0.1)"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "a5rIkwwfpk1n"
},
"source": [
"### Initialize the VGG16 model for the perceptual loss"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "9WR_NxHP51oB"
},
"outputs": [],
"source": [
"# Frozen VGG16 feature extractor for the perceptual loss. The first 22 layers of\n",
"# the convolutional stack end at conv4_3 (index 21), before its ReLU.\n",
"# NOTE(review): frames here are only mean-centered (std = 1), not normalized with\n",
"# the ImageNet statistics the pretrained VGG16 expects; kept as-is.\n",
"vgg16 = torchvision.models.vgg16(pretrained=True)\n",
"vgg16_conv_4_3 = nn.Sequential(*next(vgg16.children())[:22])\n",
"vgg16_conv_4_3.to(device)\n",
"for parameter in vgg16_conv_4_3.parameters():\n",
"    parameter.requires_grad = False"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "9-6wLaBJZqsm"
},
"source": [
"### Validation function\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "RhMMZ_I4iDFf"
},
"outputs": [],
"source": [
"def validate():\n",
"    \"\"\"Evaluate both networks on `validationloader` without gradients.\n",
"\n",
"    Mirrors the forward pass and total loss of the training loop (see the\n",
"    Training cell for a commented walk-through).\n",
"\n",
"    Returns:\n",
"        A tuple ``(psnr, loss, retImg)``: PSNR and total loss averaged over the\n",
"        validation batches, and an image grid of [frame0, ground-truth frameT,\n",
"        predicted frameT, frame1] for the first sample of the first batch.\n",
"\n",
"    Raises:\n",
"        ValueError: if `validationloader` yields no batches.\n",
"    \"\"\"\n",
"    psnr = 0\n",
"    tloss = 0\n",
"    retImg = None\n",
"    numBatches = 0\n",
"    with torch.no_grad():\n",
"        for validationIndex, (validationData, validationFrameIndex) in enumerate(validationloader, 0):\n",
"            frame0, frameT, frame1 = validationData\n",
"\n",
"            I0 = frame0.to(device)\n",
"            I1 = frame1.to(device)\n",
"            IFrame = frameT.to(device)\n",
"\n",
"            # Flows between the two input frames\n",
"            flowOut = flowComp(torch.cat((I0, I1), dim=1))\n",
"            F_0_1 = flowOut[:,:2,:,:]\n",
"            F_1_0 = flowOut[:,2:,:,:]\n",
"\n",
"            # Intermediate flows for the frame index of each sample\n",
"            fCoeff = model.getFlowCoeff(validationFrameIndex, device)\n",
"\n",
"            F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0\n",
"            F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0\n",
"\n",
"            g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)\n",
"            g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)\n",
"\n",
"            # Flow residuals and visibility map\n",
"            intrpOut = ArbTimeFlowIntrp(torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1))\n",
"\n",
"            F_t_0_f = intrpOut[:, :2, :, :] + F_t_0\n",
"            F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1\n",
"            # torch.sigmoid replaces the deprecated F.sigmoid (same result)\n",
"            V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])\n",
"            V_t_1 = 1 - V_t_0\n",
"\n",
"            g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)\n",
"            g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)\n",
"\n",
"            # Visibility- and time-weighted blend of the two warped frames\n",
"            wCoeff = model.getWarpCoeff(validationFrameIndex, device)\n",
"\n",
"            Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)\n",
"\n",
"            # For tensorboard: grid of the first sample of the first batch only\n",
"            if retImg is None:\n",
"                retImg = torchvision.utils.make_grid([revNormalize(frame0[0]), revNormalize(frameT[0]), revNormalize(Ft_p.cpu()[0]), revNormalize(frame1[0])], padding=10)\n",
"\n",
"            # Loss (same terms and weights as in training)\n",
"            recnLoss = L1_lossFn(Ft_p, IFrame)\n",
"\n",
"            prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))\n",
"\n",
"            warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(g_I1_F_t_1, IFrame) + L1_lossFn(validationFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(validationFlowBackWarp(I1, F_0_1), I0)\n",
"\n",
"            loss_smooth_1_0 = torch.mean(torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))\n",
"            loss_smooth_0_1 = torch.mean(torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))\n",
"            loss_smooth = loss_smooth_1_0 + loss_smooth_0_1\n",
"\n",
"            loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth\n",
"            tloss += loss.item()\n",
"\n",
"            # PSNR with peak 1: pixels come from ToTensor (range 0-1), and the\n",
"            # mean subtraction cancels in the difference.\n",
"            MSE_val = MSE_LossFn(Ft_p, IFrame)\n",
"            psnr += (10 * log10(1 / MSE_val.item()))\n",
"\n",
"            numBatches += 1\n",
"\n",
"    if numBatches == 0:\n",
"        raise ValueError(\"validationloader yielded no batches\")\n",
"    return (psnr / numBatches), (tloss / numBatches), retImg"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "Eh1LB1ufZziF"
},
"source": [
"### Test validation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "axBjslWlot7I"
},
"outputs": [],
"source": [
"# Sanity-check validate() once before training: PSNR, loss, grid shape and the grid itself.\n",
"valPSNR0, valLoss0, valGrid0 = validate()\n",
"print(valPSNR0, valLoss0, valGrid0.size())\n",
"plt.imshow(valGrid0.permute(1, 2, 0).numpy())"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "1PIFbXuKpBBe"
},
"source": [
"### Initialization"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {},
"colab_type": "code",
"id": "gWt-nlx2MSOk"
},
"outputs": [],
"source": [
"if TRAINING_CONTINUE:\n",
"    # map_location lets a checkpoint written on a GPU machine load on the current `device`.\n",
"    # NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which may\n",
"    # reject these checkpoints (they also store a datetime); if so and the file is\n",
"    # trusted, pass weights_only=False.\n",
"    dict1 = torch.load(CHECKPOINT_PATH, map_location=device)\n",
"    ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])\n",
"    flowComp.load_state_dict(dict1['state_dictFC'])\n",
"    # Restore Adam's state when the checkpoint has it (older checkpoints only\n",
"    # store the network weights).\n",
"    if 'state_dictOptim' in dict1:\n",
"        optimizer.load_state_dict(dict1['state_dictOptim'])\n",
"    # Resume the learning-rate schedule where it stopped. Without this the\n",
"    # scheduler restarts at INITIAL_LEARNING_RATE with a step count of 0, so the\n",
"    # MILESTONES decays would be applied at the wrong epochs.\n",
"    if 'learningRate' in dict1:\n",
"        for param_group in optimizer.param_groups:\n",
"            param_group['lr'] = dict1['learningRate']\n",
"    scheduler.last_epoch = dict1['epoch'] + 1\n",
"else:\n",
"    dict1 = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}"
]
},
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "RbQnS_KNilbR"
},
"source": [
"### Training"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "both",
"colab": {},
"colab_type": "code",
"id": "QrAS6TmP11RW"
},
"outputs": [],
"source": [
"import time\n",
"\n",
"start = time.time()\n",
"cLoss = dict1['loss']\n",
"valLoss = dict1['valLoss']\n",
"valPSNR = dict1['valPSNR']\n",
"# Continue checkpoint numbering from the starting epoch so a resumed run does\n",
"# not overwrite SuperSloMo0.ckpt, SuperSloMo1.ckpt, ... (0 on a fresh run).\n",
"checkpoint_counter = (dict1['epoch'] + 1) // CHECKPOINT_EPOCH\n",
"\n",
"### Main training loop\n",
"for epoch in range(dict1['epoch'] + 1, EPOCHS):\n",
"    clear_output()\n",
"    print(\"Epoch: \", epoch)\n",
"\n",
"    # Plots of the per-epoch averages recorded so far\n",
"    if (epoch):\n",
"        Plot(1, cLoss, 'red')\n",
"        Plot(1, valLoss, 'blue')\n",
"        Plot(2, valPSNR, 'green')\n",
"        display(plt.gcf())\n",
"\n",
"    # Append and reset\n",
"    cLoss.append([])\n",
"    valLoss.append([])\n",
"    valPSNR.append([])\n",
"    iLoss = 0\n",
"\n",
"    for trainIndex, (trainData, trainFrameIndex) in enumerate(trainloader, 0):\n",
"\n",
"        ## Getting the input and the target from the training set\n",
"        frame0, frameT, frame1 = trainData\n",
"\n",
"        I0 = frame0.to(device)\n",
"        I1 = frame1.to(device)\n",
"        IFrame = frameT.to(device)\n",
"\n",
"        optimizer.zero_grad()\n",
"\n",
"        # Calculate flow between reference frames I0 and I1\n",
"        flowOut = flowComp(torch.cat((I0, I1), dim=1))\n",
"\n",
"        # Extracting flows between I0 and I1 - F_0_1 and F_1_0\n",
"        F_0_1 = flowOut[:,:2,:,:]\n",
"        F_1_0 = flowOut[:,2:,:,:]\n",
"\n",
"        fCoeff = model.getFlowCoeff(trainFrameIndex, device)\n",
"\n",
"        # Calculate intermediate flows\n",
"        F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0\n",
"        F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0\n",
"\n",
"        # Get intermediate frames from the intermediate flows\n",
"        g_I0_F_t_0 = trainFlowBackWarp(I0, F_t_0)\n",
"        g_I1_F_t_1 = trainFlowBackWarp(I1, F_t_1)\n",
"\n",
"        # Calculate optical flow residuals and visibility maps\n",
"        intrpOut = ArbTimeFlowIntrp(torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1))\n",
"\n",
"        # Extract optical flow residuals and visibility maps.\n",
"        # torch.sigmoid replaces the deprecated F.sigmoid (same result).\n",
"        F_t_0_f = intrpOut[:, :2, :, :] + F_t_0\n",
"        F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1\n",
"        V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])\n",
"        V_t_1 = 1 - V_t_0\n",
"\n",
"        # Get intermediate frames from the refined intermediate flows\n",
"        g_I0_F_t_0_f = trainFlowBackWarp(I0, F_t_0_f)\n",
"        g_I1_F_t_1_f = trainFlowBackWarp(I1, F_t_1_f)\n",
"\n",
"        wCoeff = model.getWarpCoeff(trainFrameIndex, device)\n",
"\n",
"        # Calculate final intermediate frame\n",
"        Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)\n",
"\n",
"        # Loss\n",
"        recnLoss = L1_lossFn(Ft_p, IFrame)\n",
"\n",
"        prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))\n",
"\n",
"        warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(g_I1_F_t_1, IFrame) + L1_lossFn(trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(trainFlowBackWarp(I1, F_0_1), I0)\n",
"\n",
"        loss_smooth_1_0 = torch.mean(torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))\n",
"        loss_smooth_0_1 = torch.mean(torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))\n",
"        loss_smooth = loss_smooth_1_0 + loss_smooth_0_1\n",
"\n",
"        # Total Loss - Coefficients 204 and 102 are used instead of 0.8 and 0.4\n",
"        # since the loss in paper is calculated for input pixels in range 0-255\n",
"        # and the input to our network is in range 0-1\n",
"        loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth\n",
"\n",
"        # Backpropagate\n",
"        loss.backward()\n",
"        optimizer.step()\n",
"        iLoss += loss.item()\n",
"\n",
"        # Validation and progress every `PROGRESS_ITER` iterations\n",
"        if ((trainIndex % PROGRESS_ITER) == PROGRESS_ITER - 1):\n",
"            end = time.time()\n",
"\n",
"            psnr, vLoss, valImg = validate()\n",
"\n",
"            valPSNR[epoch].append(psnr)\n",
"            valLoss[epoch].append(vLoss)\n",
"\n",
"            #Tensorboard\n",
"            itr = trainIndex + epoch * (len(trainloader))\n",
"\n",
"            writer.add_scalars('Loss', {'trainLoss': iLoss/PROGRESS_ITER,\n",
"                                        'validationLoss': vLoss}, itr)\n",
"            writer.add_scalar('PSNR', psnr, itr)\n",
"\n",
"            writer.add_image('Validation',valImg , itr)\n",
"            #####\n",
"\n",
"            endVal = time.time()\n",
"\n",
"            print(\" Loss: %0.6f Iterations: %4d/%4d TrainExecTime: %0.1f ValLoss:%0.6f ValPSNR: %0.4f ValEvalTime: %0.2f LearningRate: %f\" % (iLoss / PROGRESS_ITER, trainIndex, len(trainloader), end - start, vLoss, psnr, endVal - end, get_lr(optimizer)))\n",
"\n",
"            cLoss[epoch].append(iLoss/PROGRESS_ITER)\n",
"            iLoss = 0\n",
"            start = time.time()\n",
"\n",
"    # Step the LR scheduler once per epoch, *after* the epoch's optimizer.step()\n",
"    # calls, which is the order PyTorch >= 1.1 expects. Stepping at the start of the\n",
"    # epoch shifted the schedule one epoch early (with MILESTONES = [100, 150]\n",
"    # the first decay was applied during epoch 99). Stepping before the checkpoint\n",
"    # makes the saved 'learningRate' the rate the next epoch will use.\n",
"    scheduler.step()\n",
"\n",
"    # Create checkpoint after every `CHECKPOINT_EPOCH` epochs\n",
"    if ((epoch % CHECKPOINT_EPOCH) == CHECKPOINT_EPOCH - 1):\n",
"        dict1 = {\n",
"            'Detail':\"End to end Super SloMo.\",\n",
"            'epoch':epoch,\n",
"            'timestamp':datetime.datetime.now(),\n",
"            'trainBatchSz':TRAIN_BATCH_SIZE,\n",
"            'validationBatchSz':VALIDATION_BATCH_SIZE,\n",
"            'learningRate':get_lr(optimizer),\n",
"            'loss':cLoss,\n",
"            'valLoss':valLoss,\n",
"            'valPSNR':valPSNR,\n",
"            'state_dictFC': flowComp.state_dict(),\n",
"            'state_dictAT': ArbTimeFlowIntrp.state_dict(),\n",
"            # Adam moment estimates, so a resumed run continues the same optimization\n",
"            'state_dictOptim': optimizer.state_dict(),\n",
"        }\n",
"        torch.save(dict1, CHECKPOINT_DIR + \"/SuperSloMo\" + str(checkpoint_counter) + \".ckpt\")\n",
"        checkpoint_counter += 1\n",
"    plt.close('all')"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"collapsed_sections": [],
"name": "train.ipynb",
"provenance": [],
"version": "0.3.2"
},
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.0"
}
},
"nbformat": 4,
"nbformat_minor": 2
}