# torchbenchmark/models/Background_Matting/train_real_fixed.py
from __future__ import print_function
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
import os
import time
import argparse
import numpy as np
from data_loader import VideoData
from functions import *
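# `functions` supplies helpers used below, including compose_image_withshift and write_tb_log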
from networks import ResnetConditionHR, MultiscaleDiscriminator, conv_init
from loss_functions import alpha_loss, compose_loss, alpha_gradient_loss, GANloss
#CUDA
#os.environ["CUDA_VISIBLE_DEVICES"]="4"
print('CUDA Device: ' + os.environ.get("CUDA_VISIBLE_DEVICES", "<unset>"))  # avoid a KeyError when the variable is not set
"""Parses arguments."""
parser = argparse.ArgumentParser(description='Training Background Matting on Adobe Dataset.')
parser.add_argument('-n', '--name', type=str, help='Name of tensorboard and model saving folders.')
parser.add_argument('-bs', '--batch_size', type=int, help='Batch Size.')
parser.add_argument('-res', '--reso', type=int, help='Input image resolution')
parser.add_argument('-init_model', '--init_model', type=str, help='Initial model file')
parser.add_argument('-epoch', '--epoch', type=int, default=10,help='Maximum Epoch')
parser.add_argument('-n_blocks1', '--n_blocks1', type=int, default=7,help='Number of residual blocks after Context Switching.')
parser.add_argument('-n_blocks2', '--n_blocks2', type=int, default=3,help='Number of residual blocks for Fg and alpha each.')
parser.add_argument('-d', '--debug', type=str, default="", help='File to dump output')
parser.add_argument('-s', '--script', action='store_true', help='Trace the model')  # type=bool would parse any non-empty string, including "False", as True
args=parser.parse_args()
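# Example invocation (hypothetical name/paths/values):
#   python train_real_fixed.py -n real_fixed -bs 4 -res 512 -init_model Models/syn_init.pth -epoch 10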
##Directories
tb_dir='TB_Summary/' + args.name
model_dir='Models/' + args.name
torch.manual_seed(1337)
np.random.seed(1337)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
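# Fixed seeds and deterministic cuDNN settings keep benchmark runs reproducible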
if not os.path.exists(model_dir):
os.makedirs(model_dir)
if not os.path.exists(tb_dir):
os.makedirs(tb_dir)
## Input list
data_config_train = {'reso': (args.reso, args.reso)}  # square input resolution; no trimap is used in this config
# DATA LOADING
print('\n[Phase 1] : Data Preparation')
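# Samples the dataset fails to load come back as None; drop them so the rest of the batch still collates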
def collate_filter_none(batch):
batch = list(filter(lambda x: x is not None, batch))
return torch.utils.data.dataloader.default_collate(batch)
#Original Data
traindata = VideoData(csv_file='Video_data_train.csv', data_config=data_config_train, transform=None)  # VideoData reads the training samples listed in the CSV file
train_loader = torch.utils.data.DataLoader(traindata, batch_size=args.batch_size, shuffle=True, num_workers=args.batch_size, collate_fn=collate_filter_none)  # one loader worker per sample in a batch
print('\n[Phase 2] : Initialization')
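# netB is a frozen auxiliary network that generates the pseudo-supervision targets used in the training loop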
netB=ResnetConditionHR(input_nc=(3,3,1,4),output_nc=4,n_blocks1=args.n_blocks1,n_blocks2=args.n_blocks2)
#netB=nn.DataParallel(netB)
#netB.load_state_dict(torch.load(args.init_model))
netB.cuda()
netB.eval()
for param in netB.parameters():  # freeze netB; it only supplies pseudo-supervision targets
    param.requires_grad = False
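# netG is the generator actually being trained; same architecture as netB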
netG=ResnetConditionHR(input_nc=(3,3,1,4),output_nc=4,n_blocks1=args.n_blocks1,n_blocks2=args.n_blocks2)
netG.apply(conv_init)
#netG=nn.DataParallel(netG)
netG.cuda()
torch.backends.cudnn.benchmark = True  # note: this re-enables benchmark mode, overriding the deterministic setting above
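# netD is a multiscale discriminator (a single scale here) that judges composited images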
netD=MultiscaleDiscriminator(input_nc=3,num_D=1,norm_layer=nn.InstanceNorm2d,ndf=64)
netD.apply(conv_init)
netD=nn.DataParallel(netD)
netD.cuda()
#Loss
l1_loss=alpha_loss()
c_loss=compose_loss()
g_loss=alpha_gradient_loss()
GAN_loss=GANloss()
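# Separate Adam optimizers; the discriminator uses a 10x smaller learning rate than the generator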
optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
optimizerD = optim.Adam(netD.parameters(), lr=1e-5)
log_writer=SummaryWriter(tb_dir)
step = 50  # print/log interval in iterations
KK = len(train_loader)  # batches per epoch; used as the global-step multiplier for TensorBoard
wt = 1  # weight on the pseudo-supervised losses; halved every other epoch below
print('Tracing')
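# Run one batch through both networks; with --script, capture them with torch.jit.trace first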
for data in train_loader:
    bg, image, seg, multi_fr = data['bg'], data['image'], data['seg'], data['multi_fr']
    bg, image, seg, multi_fr = bg.cuda(), image.cuda(), seg.cuda(), multi_fr.cuda()
    if args.script:
        netB = torch.jit.trace(netB, (image, bg, seg, multi_fr))
        netG = torch.jit.trace(netG, (image, bg, seg, multi_fr))
    else:
        netB(image, bg, seg, multi_fr)
        netG(image, bg, seg, multi_fr)
    break  # one batch is enough to trace/warm up the networks
print('Starting training')
for epoch in range(0,args.epoch):
    netG.train()
    netD.train()
    lG, lD, GenL, DisL_r, DisL_f, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
    t0 = time.time()
for i,data in enumerate(train_loader):
#Initiating
bg, image, seg, multi_fr, seg_gt, back_rnd = data['bg'], data['image'], data['seg'], data['multi_fr'], data['seg-gt'], data['back-rnd']
        bg, image, seg, multi_fr, seg_gt, back_rnd = bg.cuda(), image.cuda(), seg.cuda(), multi_fr.cuda(), seg_gt.cuda(), back_rnd.cuda()
        mask0 = torch.ones(seg.shape).cuda()  # all-ones mask: supervise alpha everywhere
tr0=time.time()
        # pseudo-supervision: the frozen netB provides alpha/foreground targets for netG
        alpha_pred_sup, fg_pred_sup = netB(image, bg, seg, multi_fr)
        mask = (alpha_pred_sup > -0.98).float()  # regions netB does not consider pure background (alpha is in [-1,1])
        mask1 = (seg_gt > 0.95).float()  # confident foreground in the segmentation ground truth
## Train Generator
alpha_pred,fg_pred=netG(image,bg,seg,multi_fr)
if args.debug:
torch.save(fg_pred, args.debug)
##pseudo-supervised losses
al_loss=l1_loss(alpha_pred_sup,alpha_pred,mask0)+0.5*g_loss(alpha_pred_sup,alpha_pred,mask0)
fg_loss=l1_loss(fg_pred_sup,fg_pred,mask)
#compose into same background
comp_loss= c_loss(image,alpha_pred,fg_pred,bg,mask1)
#randomly permute the background
perm=torch.LongTensor(np.random.permutation(bg.shape[0]))
bg_sh=bg[perm,:,:,:]
        al_mask = (alpha_pred > 0.95).float()  # near-certain foreground in the current prediction
#Choose the target background for composition
#back_rnd: contains separate set of background videos captured
#bg_sh: contains randomly permuted captured background from the same minibatch
if np.random.random_sample() > 0.5:
bg_sh=back_rnd
image_sh=compose_image_withshift(alpha_pred,image*al_mask + fg_pred*(1-al_mask),bg_sh,seg)
fake_response=netD(image_sh)
loss_ganG=GAN_loss(fake_response,label_type=True)
        lossG = loss_ganG + wt*(0.05*comp_loss + 0.05*al_loss + 0.05*fg_loss)  # adversarial term plus decaying pseudo-supervised terms
        optimizerG.zero_grad()
        lossG.backward(retain_graph=True)  # keep the graph alive: the discriminator pass below reuses image_sh, which was produced by netG
        optimizerG.step()
        ##Train Discriminator
        fake_response = netD(image_sh)
        real_response = netD(image)
        loss_ganD_fake = GAN_loss(fake_response, label_type=False)
        loss_ganD_real = GAN_loss(real_response, label_type=True)
        lossD = (loss_ganD_real + loss_ganD_fake)*0.5
        # Update the discriminator once every 5 generator updates
        if i % 5 == 0:
optimizerD.zero_grad()
lossD.backward()
optimizerD.step()
        lG += lossG.item()
        lD += lossD.item()
        GenL += loss_ganG.item()
        DisL_r += loss_ganD_real.item()
        DisL_f += loss_ganD_fake.item()
        alL += al_loss.item()
        fgL += fg_loss.item()
        compL += comp_loss.item()
        log_writer.add_scalar('Generator Loss', lossG.item(), epoch*KK + i + 1)
        log_writer.add_scalar('Discriminator Loss', lossD.item(), epoch*KK + i + 1)
        log_writer.add_scalar('Generator Loss: Fake', loss_ganG.item(), epoch*KK + i + 1)
        log_writer.add_scalar('Discriminator Loss: Real', loss_ganD_real.item(), epoch*KK + i + 1)
        log_writer.add_scalar('Discriminator Loss: Fake', loss_ganD_fake.item(), epoch*KK + i + 1)
        log_writer.add_scalar('Generator Loss: Alpha', al_loss.item(), epoch*KK + i + 1)
        log_writer.add_scalar('Generator Loss: Fg', fg_loss.item(), epoch*KK + i + 1)
        log_writer.add_scalar('Generator Loss: Comp', comp_loss.item(), epoch*KK + i + 1)
        t1 = time.time()
        elapse += t1 - t0
        elapse_run += t1 - tr0
        t0 = t1
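        # Every `step` iterations: print running averages, reset the accumulators, and write image summaries to TensorBoard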
if i % step == (step-1):
print('[%d, %5d] Gen-loss: %.4f Disc-loss: %.4f Alpha-loss: %.4f Fg-loss: %.4f Comp-loss: %.4f Time-all: %.4f Time-fwbw: %.4f' %(epoch + 1, i + 1, lG/step,lD/step,alL/step,fgL/step,compL/step,elapse/step,elapse_run/step))
lG, lD, GenL, DisL_r, DisL_f, alL, fgL, compL, elapse_run, elapse=0,0,0,0,0,0,0,0,0,0
write_tb_log(image,'image',log_writer,i)
write_tb_log(seg,'seg',log_writer,i)
write_tb_log(alpha_pred_sup,'alpha-sup',log_writer,i)
write_tb_log(alpha_pred,'alpha_pred',log_writer,i)
write_tb_log(fg_pred_sup*mask,'fg-pred-sup',log_writer,i)
write_tb_log(fg_pred*mask,'fg_pred',log_writer,i)
            # composition: map alpha from [-1,1] to [0,1], then composite onto the original background
            alpha_pred = (alpha_pred + 1) / 2
            comp = fg_pred*alpha_pred + (1-alpha_pred)*bg
write_tb_log(comp,'composite-same',log_writer,i)
write_tb_log(image_sh,'composite-diff',log_writer,i)
del comp
        del mask, back_rnd, mask0, seg_gt, mask1, bg, alpha_pred, alpha_pred_sup, image, fg_pred_sup, fg_pred, seg, multi_fr, image_sh, bg_sh, fake_response, real_response, al_loss, fg_loss, comp_loss, lossG, lossD, loss_ganD_real, loss_ganD_fake, loss_ganG  # release GPU tensors before the next iteration
    if epoch % 2 == 0:
        # model_dir does not end with a separator, so join paths to save inside the directory
        torch.save(netG.state_dict(), os.path.join(model_dir, 'netG_epoch_%d.pth' % epoch))
        torch.save(optimizerG.state_dict(), os.path.join(model_dir, 'optimG_epoch_%d.pth' % epoch))
        torch.save(netD.state_dict(), os.path.join(model_dir, 'netD_epoch_%d.pth' % epoch))
        torch.save(optimizerD.state_dict(), os.path.join(model_dir, 'optimD_epoch_%d.pth' % epoch))
        # Halve the pseudo-supervision weight every 2 epochs to put more stress on the discriminator signal and less on pseudo-supervision
        wt = wt / 2