torchbenchmark/models/Super_SloMo/train.ipynb

{ "cells": [ { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "f_KNv25DX7B6" }, "source": [ "#[Super SloMo](https://people.cs.umass.edu/~hzjiang/projects/superslomo/)\n", "##High Quality Estimation of Multiple Intermediate Frames for Video Interpolation\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "0VWuBGh6zMMZ" }, "outputs": [], "source": [ "%matplotlib inline\n", "import torch\n", "import torchvision\n", "import torchvision.transforms as transforms\n", "import torch.optim as optim\n", "import torch.nn as nn\n", "import torch.nn.functional as F\n", "import slomo_model as model\n", "import dataloader\n", "import matplotlib.pyplot as plt\n", "from math import log10\n", "from IPython.display import clear_output, display\n", "import datetime\n", "from tensorboardX import SummaryWriter" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "1VynXmoKp_3M" }, "source": [ "##Parameters" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "N2yrOVZjqDe9" }, "outputs": [], "source": [ "# Learning Rate. Set `MILESTONES` to epoch values where you want to decrease\n", "# learning rate by a factor of 0.1\n", "INITIAL_LEARNING_RATE = 0.0001\n", "MILESTONES = [100, 150]\n", "\n", "# Number of epochs to train\n", "EPOCHS = 200\n", "\n", "# Choose batchsize as per GPU/CPU configuration\n", "# This configuration works on GTX 1080 Ti\n", "TRAIN_BATCH_SIZE = 6\n", "VALIDATION_BATCH_SIZE = 10\n", "\n", "# Path to dataset folder containing train-test-validation folders\n", "DATASET_ROOT = \"path/to/dataset\"\n", "\n", "# Path to folder for saving checkpoints\n", "CHECKPOINT_DIR = 'path/to/checkpoint_directory'\n", "\n", "# If resuming from checkpoint, set `trainingContinue` to True and set `checkpoint_path`\n", "TRAINING_CONTINUE = False\n", "CHECKPOINT_PATH = 'path/to/checkpoint/file'\n", "\n", "# Progress and validation frequency (N: after every N iterations)\n", "PROGRESS_ITER = 100\n", "\n", "# Checkpoint frequency (N: after every N epochs). Each checkpoint is roughly of size 151 MB.\n", "CHECKPOINT_EPOCH = 5" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Yr3Lm1ovbWv1" }, "source": [ "##[TensorboardX](https://github.com/lanpa/tensorboardX)\n", "### For visualizing loss and interpolated frames" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "saUJTMiMCAzH" }, "outputs": [], "source": [ "writer = SummaryWriter('log')" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Ua1DJm82aj5-" }, "source": [ "###Initialize flow computation and arbitrary-time flow interpolation CNNs." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "D42vzEKrWtpG" }, "outputs": [], "source": [ "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n", "flowComp = model.UNet(6, 4)\n", "flowComp.to(device)\n", "ArbTimeFlowIntrp = model.UNet(20, 5)\n", "ArbTimeFlowIntrp.to(device)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "UYMpk2EYchaY" }, "source": [ "###Initialze backward warpers for train and validation datasets" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "vJq6SrWIf2GE" }, "outputs": [], "source": [ "trainFlowBackWarp = model.backWarp(352, 352, device)\n", "trainFlowBackWarp = trainFlowBackWarp.to(device)\n", "validationFlowBackWarp = model.backWarp(640, 352, device)\n", "validationFlowBackWarp = validationFlowBackWarp.to(device)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "oSs9UaIjdTT2" }, "source": [ "###Load Datasets" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "MJ9cVigEgtyT" }, "outputs": [], "source": [ "# Channel wise mean calculated on adobe240-fps training dataset\n", "mean = [0.429, 0.431, 0.397]\n", "std = [1, 1, 1]\n", "normalize = transforms.Normalize(mean=mean,\n", " std=std)\n", "transform = transforms.Compose([transforms.ToTensor(), normalize])\n", "\n", "trainset = dataloader.SuperSloMo(root=DATASET_ROOT + '/train', transform=transform, train=True)\n", "trainloader = torch.utils.data.DataLoader(trainset, batch_size=TRAIN_BATCH_SIZE, shuffle=True)\n", "\n", "validationset = dataloader.SuperSloMo(root=DATASET_ROOT + '/validation', transform=transform, randomCropSize=(640, 352), train=False)\n", "validationloader = torch.utils.data.DataLoader(validationset, batch_size=VALIDATION_BATCH_SIZE, shuffle=False)\n", "\n", "print(trainset, validationset)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "WXmNMdbJfp2d" }, "source": [ "###Create transform to display image from tensor" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "try3adPHgwse" }, "outputs": [], "source": [ "negmean = [x * -1 for x in mean]\n", "revNormalize = transforms.Normalize(mean=negmean, std=std)\n", "TP = transforms.Compose([revNormalize, transforms.ToPILImage()])" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "32XZg9Mfd5bN" }, "source": [ "###Test the dataset" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "0Vyf7dbwCO1E" }, "outputs": [], "source": [ "for trainIndex, (trainData, frameIndex) in enumerate(trainloader, 0):\n", " frame0, frameT, frame1 = trainData\n", " print(\"Intermediate frame index: \", (frameIndex[0]))\n", " plt.imshow(TP(frame0[0]))\n", " plt.grid(True)\n", " plt.figure()\n", " plt.imshow(TP(frameT[0]))\n", " plt.grid(True)\n", " plt.figure()\n", " plt.imshow(TP(frame1[0]))\n", " plt.grid(True)\n", " break" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "rh0MK2qKuBlV" }, "source": [ "###Utils" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "BdMFU0ijfIuI" }, "outputs": [], "source": [ "plt.rcParams['figure.figsize'] = [15, 3]\n", "def Plot(num, listInp, d):\n", " a = listInp\n", " c = []\n", " for b in a:\n", " c.append(sum(b)/len(b))\n", " plt.subplot(1, 2, 
num)\n", " plt.plot(c, color=d)\n", " plt.grid(True)\n", " \n", "def get_lr(optimizer):\n", " for param_group in optimizer.param_groups:\n", " return param_group['lr']" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "mooLcmxtpPR_" }, "source": [ "###Loss and Optimizer" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "BuWQfcb-jhWx" }, "outputs": [], "source": [ "L1_lossFn = nn.L1Loss()\n", "MSE_LossFn = nn.MSELoss()\n", "\n", "params = list(ArbTimeFlowIntrp.parameters()) + list(flowComp.parameters())\n", "\n", "optimizer = optim.Adam(params, lr=INITIAL_LEARNING_RATE)\n", "# scheduler to decrease learning rate by a factor of 10 at milestones.\n", "scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=MILESTONES, gamma=0.1)" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "a5rIkwwfpk1n" }, "source": [ "###Initializing VGG16 model for perceptual loss" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "9WR_NxHP51oB" }, "outputs": [], "source": [ "vgg16 = torchvision.models.vgg16(pretrained=True)\n", "vgg16_conv_4_3 = nn.Sequential(*list(vgg16.children())[0][:22])\n", "vgg16_conv_4_3.to(device)\n", "for param in vgg16_conv_4_3.parameters():\n", "\t\tparam.requires_grad = False" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "9-6wLaBJZqsm" }, "source": [ "### Validation function\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "RhMMZ_I4iDFf" }, "outputs": [], "source": [ "def validate():\n", " # For details see training.\n", " psnr = 0\n", " tloss = 0\n", " flag = 1\n", " with torch.no_grad():\n", " for validationIndex, (validationData, validationFrameIndex) in enumerate(validationloader, 0):\n", " frame0, frameT, frame1 = validationData\n", "\n", " I0 = frame0.to(device)\n", " I1 = frame1.to(device)\n", " IFrame = frameT.to(device)\n", " \n", " \n", " flowOut = flowComp(torch.cat((I0, I1), dim=1))\n", " F_0_1 = flowOut[:,:2,:,:]\n", " F_1_0 = flowOut[:,2:,:,:]\n", "\n", " fCoeff = model.getFlowCoeff(validationFrameIndex, device)\n", "\n", " F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0\n", " F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0\n", "\n", " g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)\n", " g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)\n", " \n", " intrpOut = ArbTimeFlowIntrp(torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1))\n", " \n", " F_t_0_f = intrpOut[:, :2, :, :] + F_t_0\n", " F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1\n", " V_t_0 = F.sigmoid(intrpOut[:, 4:5, :, :])\n", " V_t_1 = 1 - V_t_0\n", " \n", " g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)\n", " g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)\n", " \n", " wCoeff = model.getWarpCoeff(validationFrameIndex, device)\n", " \n", " Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)\n", " \n", " # For tensorboard\n", " if (flag):\n", " retImg = torchvision.utils.make_grid([revNormalize(frame0[0]), revNormalize(frameT[0]), revNormalize(Ft_p.cpu()[0]), revNormalize(frame1[0])], padding=10)\n", " flag = 0\n", " \n", " \n", " #loss\n", " recnLoss = L1_lossFn(Ft_p, IFrame)\n", " \n", " prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))\n", " \n", " warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(g_I1_F_t_1, IFrame) + 
, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "9-6wLaBJZqsm" }, "source": [ "### Validation function\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "RhMMZ_I4iDFf" }, "outputs": [], "source": [ "def validate():\n", "    # For details see training.\n", "    psnr = 0\n", "    tloss = 0\n", "    flag = 1\n", "    with torch.no_grad():\n", "        for validationIndex, (validationData, validationFrameIndex) in enumerate(validationloader, 0):\n", "            frame0, frameT, frame1 = validationData\n", "\n", "            I0 = frame0.to(device)\n", "            I1 = frame1.to(device)\n", "            IFrame = frameT.to(device)\n", "\n", "            flowOut = flowComp(torch.cat((I0, I1), dim=1))\n", "            F_0_1 = flowOut[:, :2, :, :]\n", "            F_1_0 = flowOut[:, 2:, :, :]\n", "\n", "            fCoeff = model.getFlowCoeff(validationFrameIndex, device)\n", "\n", "            F_t_0 = fCoeff[0] * F_0_1 + fCoeff[1] * F_1_0\n", "            F_t_1 = fCoeff[2] * F_0_1 + fCoeff[3] * F_1_0\n", "\n", "            g_I0_F_t_0 = validationFlowBackWarp(I0, F_t_0)\n", "            g_I1_F_t_1 = validationFlowBackWarp(I1, F_t_1)\n", "\n", "            intrpOut = ArbTimeFlowIntrp(torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1))\n", "\n", "            F_t_0_f = intrpOut[:, :2, :, :] + F_t_0\n", "            F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1\n", "            V_t_0 = torch.sigmoid(intrpOut[:, 4:5, :, :])\n", "            V_t_1 = 1 - V_t_0\n", "\n", "            g_I0_F_t_0_f = validationFlowBackWarp(I0, F_t_0_f)\n", "            g_I1_F_t_1_f = validationFlowBackWarp(I1, F_t_1_f)\n", "\n", "            wCoeff = model.getWarpCoeff(validationFrameIndex, device)\n", "\n", "            Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)\n", "\n", "            # For tensorboard: keep a grid of the first validation sample\n", "            if (flag):\n", "                retImg = torchvision.utils.make_grid([revNormalize(frame0[0]), revNormalize(frameT[0]), revNormalize(Ft_p.cpu()[0]), revNormalize(frame1[0])], padding=10)\n", "                flag = 0\n", "\n", "            # Loss\n", "            recnLoss = L1_lossFn(Ft_p, IFrame)\n", "\n", "            prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))\n", "\n", "            warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(g_I1_F_t_1, IFrame) + L1_lossFn(validationFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(validationFlowBackWarp(I1, F_0_1), I0)\n", "\n", "            loss_smooth_1_0 = torch.mean(torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))\n", "            loss_smooth_0_1 = torch.mean(torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))\n", "            loss_smooth = loss_smooth_1_0 + loss_smooth_0_1\n", "\n", "            loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth\n", "            tloss += loss.item()\n", "\n", "            # PSNR\n", "            MSE_val = MSE_LossFn(Ft_p, IFrame)\n", "            psnr += (10 * log10(1 / MSE_val.item()))\n", "\n", "    return (psnr / len(validationloader)), (tloss / len(validationloader)), retImg" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Eh1LB1ufZziF" }, "source": [ "### Test validation" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "axBjslWlot7I" }, "outputs": [], "source": [ "a, b, c = validate()\n", "print(a, b, c.size())\n", "plt.imshow(c.permute(1, 2, 0).numpy())" ] }, { "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "1PIFbXuKpBBe" }, "source": [ "### Initialization" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": {}, "colab_type": "code", "id": "gWt-nlx2MSOk" }, "outputs": [], "source": [ "if TRAINING_CONTINUE:\n", "    dict1 = torch.load(CHECKPOINT_PATH)\n", "    ArbTimeFlowIntrp.load_state_dict(dict1['state_dictAT'])\n", "    flowComp.load_state_dict(dict1['state_dictFC'])\n", "else:\n", "    dict1 = {'loss': [], 'valLoss': [], 'valPSNR': [], 'epoch': -1}" ] }
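, { "cell_type": "markdown", "metadata": {}, "source": [ "###Intermediate-flow coefficients (illustrative)\n", "For reference, the Super SloMo paper approximates the intermediate flows as F_t_0 = -(1 - t) * t * F_0_1 + t^2 * F_1_0 and F_t_1 = (1 - t)^2 * F_0_1 - t * (1 - t) * F_1_0. `model.getFlowCoeff` presumably returns these four coefficients for the sampled intermediate frame index; the standalone sketch below is an assumption for illustration, not the actual helper." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Hedged sketch of the intermediate-flow coefficients from the paper;\n", "# the real model.getFlowCoeff maps a frame index to t and may differ in detail.\n", "def flow_coeffs(t):\n", "    C00 = -(1 - t) * t        # weight of F_0_1 in F_t_0\n", "    C01 = t * t               # weight of F_1_0 in F_t_0\n", "    C10 = (1 - t) * (1 - t)   # weight of F_0_1 in F_t_1\n", "    C11 = -t * (1 - t)        # weight of F_1_0 in F_t_1\n", "    return C00, C01, C10, C11\n", "\n", "print(flow_coeffs(0.5))  # midpoint: (-0.25, 0.25, 0.25, -0.25)" ] }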
F_t_0)\n", " g_I1_F_t_1 = trainFlowBackWarp(I1, F_t_1)\n", " \n", " # Calculate optical flow residuals and visibility maps\n", " intrpOut = ArbTimeFlowIntrp(torch.cat((I0, I1, F_0_1, F_1_0, F_t_1, F_t_0, g_I1_F_t_1, g_I0_F_t_0), dim=1))\n", " \n", " # Extract optical flow residuals and visibility maps\n", " F_t_0_f = intrpOut[:, :2, :, :] + F_t_0\n", " F_t_1_f = intrpOut[:, 2:4, :, :] + F_t_1\n", " V_t_0 = F.sigmoid(intrpOut[:, 4:5, :, :])\n", " V_t_1 = 1 - V_t_0\n", " \n", " # Get intermediate frames from the intermediate flows\n", " g_I0_F_t_0_f = trainFlowBackWarp(I0, F_t_0_f)\n", " g_I1_F_t_1_f = trainFlowBackWarp(I1, F_t_1_f)\n", " \n", " wCoeff = model.getWarpCoeff(trainFrameIndex, device)\n", " \n", " # Calculate final intermediate frame \n", " Ft_p = (wCoeff[0] * V_t_0 * g_I0_F_t_0_f + wCoeff[1] * V_t_1 * g_I1_F_t_1_f) / (wCoeff[0] * V_t_0 + wCoeff[1] * V_t_1)\n", " \n", " # Loss\n", " recnLoss = L1_lossFn(Ft_p, IFrame)\n", " \n", " prcpLoss = MSE_LossFn(vgg16_conv_4_3(Ft_p), vgg16_conv_4_3(IFrame))\n", " \n", " warpLoss = L1_lossFn(g_I0_F_t_0, IFrame) + L1_lossFn(g_I1_F_t_1, IFrame) + L1_lossFn(trainFlowBackWarp(I0, F_1_0), I1) + L1_lossFn(trainFlowBackWarp(I1, F_0_1), I0)\n", " \n", " loss_smooth_1_0 = torch.mean(torch.abs(F_1_0[:, :, :, :-1] - F_1_0[:, :, :, 1:])) + torch.mean(torch.abs(F_1_0[:, :, :-1, :] - F_1_0[:, :, 1:, :]))\n", " loss_smooth_0_1 = torch.mean(torch.abs(F_0_1[:, :, :, :-1] - F_0_1[:, :, :, 1:])) + torch.mean(torch.abs(F_0_1[:, :, :-1, :] - F_0_1[:, :, 1:, :]))\n", " loss_smooth = loss_smooth_1_0 + loss_smooth_0_1\n", " \n", " # Total Loss - Coefficients 204 and 102 are used instead of 0.8 and 0.4\n", " # since the loss in paper is calculated for input pixels in range 0-255\n", " # and the input to our network is in range 0-1\n", " loss = 204 * recnLoss + 102 * warpLoss + 0.005 * prcpLoss + loss_smooth\n", " \n", " # Backpropagate\n", " loss.backward()\n", " optimizer.step()\n", " iLoss += loss.item()\n", " \n", " # Validation and progress every `PROGRESS_ITER` iterations\n", " if ((trainIndex % PROGRESS_ITER) == PROGRESS_ITER - 1):\n", " end = time.time()\n", " \n", " psnr, vLoss, valImg = validate()\n", " \n", " valPSNR[epoch].append(psnr)\n", " valLoss[epoch].append(vLoss)\n", " \n", " #Tensorboard\n", " itr = trainIndex + epoch * (len(trainloader))\n", " \n", " writer.add_scalars('Loss', {'trainLoss': iLoss/PROGRESS_ITER,\n", " 'validationLoss': vLoss}, itr)\n", " writer.add_scalar('PSNR', psnr, itr)\n", " \n", " writer.add_image('Validation',valImg , itr)\n", " #####\n", " \n", " endVal = time.time()\n", " \n", " print(\" Loss: %0.6f Iterations: %4d/%4d TrainExecTime: %0.1f ValLoss:%0.6f ValPSNR: %0.4f ValEvalTime: %0.2f LearningRate: %f\" % (iLoss / PROGRESS_ITER, trainIndex, len(trainloader), end - start, vLoss, psnr, endVal - end, get_lr(optimizer)))\n", " \n", " \n", " cLoss[epoch].append(iLoss/PROGRESS_ITER)\n", " iLoss = 0\n", " start = time.time()\n", " \n", " # Create checkpoint after every `CHECKPOINT_EPOCH` epochs\n", " if ((epoch % CHECKPOINT_EPOCH) == CHECKPOINT_EPOCH - 1):\n", " dict1 = {\n", " 'Detail':\"End to end Super SloMo.\",\n", " 'epoch':epoch,\n", " 'timestamp':datetime.datetime.now(),\n", " 'trainBatchSz':TRAIN_BATCH_SIZE,\n", " 'validationBatchSz':VALIDATION_BATCH_SIZE,\n", " 'learningRate':get_lr(optimizer),\n", " 'loss':cLoss,\n", " 'valLoss':valLoss,\n", " 'valPSNR':valPSNR,\n", " 'state_dictFC': flowComp.state_dict(),\n", " 'state_dictAT': ArbTimeFlowIntrp.state_dict(),\n", " }\n", " torch.save(dict1, CHECKPOINT_DIR + 
\"/SuperSloMo\" + str(checkpoint_counter) + \".ckpt\")\n", " checkpoint_counter += 1\n", " plt.close('all')" ] } ], "metadata": { "accelerator": "GPU", "colab": { "collapsed_sections": [], "name": "train.ipynb", "provenance": [], "version": "0.3.2" }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.0" } }, "nbformat": 4, "nbformat_minor": 2 }