torchbenchmark/models/Background_Matting/train_real_fixed.py:
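"""Train Background Matting on real fixed-camera videos.

A frozen copy of the matting network (netB) provides pseudo-supervision for the
trainable generator (netG), while a multiscale discriminator (netD) scores
composites of the predicted foreground over shifted or permuted backgrounds.
"""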

from __future__ import print_function

import argparse
import os
import time

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.autograd import Variable

from data_loader import VideoData
from functions import *
from networks import ResnetConditionHR, MultiscaleDiscriminator, conv_init
from loss_functions import alpha_loss, compose_loss, alpha_gradient_loss, GANloss

# CUDA
# os.environ["CUDA_VISIBLE_DEVICES"] = "4"
print('CUDA Device: ' + os.environ["CUDA_VISIBLE_DEVICES"])

# Parse arguments.
parser = argparse.ArgumentParser(description='Training Background Matting on real fixed-camera videos.')
parser.add_argument('-n', '--name', type=str, help='Name of tensorboard and model saving folders.')
parser.add_argument('-bs', '--batch_size', type=int, help='Batch size.')
parser.add_argument('-res', '--reso', type=int, help='Input image resolution.')
parser.add_argument('-init_model', '--init_model', type=str, help='Initial model file.')
parser.add_argument('-epoch', '--epoch', type=int, default=10, help='Maximum number of epochs.')
parser.add_argument('-n_blocks1', '--n_blocks1', type=int, default=7, help='Number of residual blocks after Context Switching.')
parser.add_argument('-n_blocks2', '--n_blocks2', type=int, default=3, help='Number of residual blocks for Fg and alpha each.')
parser.add_argument('-d', '--debug', type=str, default="", help='File to dump output.')
parser.add_argument('-s', '--script', type=bool, default=False, help='Trace the model.')  # note: argparse's type=bool parses any non-empty string as True
args = parser.parse_args()

## Directories
tb_dir = 'TB_Summary/' + args.name
model_dir = 'Models/' + args.name

torch.manual_seed(1337)
np.random.seed(1337)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

if not os.path.exists(model_dir):
    os.makedirs(model_dir)
if not os.path.exists(tb_dir):
    os.makedirs(tb_dir)

## Input list
data_config_train = {'reso': (args.reso, args.reso)}  # if trimap is true, rcnn is used

# DATA LOADING
print('\n[Phase 1] : Data Preparation')


def collate_filter_none(batch):
    # Drop samples the dataset failed to load before collating the batch.
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)


# Original data: a dataloader that reads the database provided by the .csv file.
traindata = VideoData(csv_file='Video_data_train.csv', data_config=data_config_train, transform=None)
train_loader = torch.utils.data.DataLoader(traindata, batch_size=args.batch_size, shuffle=True,
                                           num_workers=args.batch_size, collate_fn=collate_filter_none)

print('\n[Phase 2] : Initialization')

# netB: frozen teacher network that provides pseudo-supervision.
netB = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=args.n_blocks1, n_blocks2=args.n_blocks2)
# netB = nn.DataParallel(netB)
# netB.load_state_dict(torch.load(args.init_model))
netB.cuda()
netB.eval()
for param in netB.parameters():  # freeze netB
    param.requires_grad = False

# netG: the generator being trained.
netG = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4, n_blocks1=args.n_blocks1, n_blocks2=args.n_blocks2)
netG.apply(conv_init)
# netG = nn.DataParallel(netG)
netG.cuda()
torch.backends.cudnn.benchmark = True

netD = MultiscaleDiscriminator(input_nc=3, num_D=1, norm_layer=nn.InstanceNorm2d, ndf=64)
netD.apply(conv_init)
netD = nn.DataParallel(netD)
netD.cuda()

# Losses
l1_loss = alpha_loss()
c_loss = compose_loss()
g_loss = alpha_gradient_loss()
GAN_loss = GANloss()

optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
optimizerD = optim.Adam(netD.parameters(), lr=1e-5)

log_writer = SummaryWriter(tb_dir)

step = 50  # logging interval (iterations)
KK = len(train_loader)
wt = 1  # weight on the pseudo-supervised losses; halved every 2 epochs
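# One-batch warm-up pass: with -s/--script, torch.jit.trace records the ops run
# on this example batch and replaces netB/netG with ScriptModules specialized
# to these input shapes; otherwise a plain eager forward pass warms things up.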
print('Tracing')
for data in train_loader:
    bg, image, seg, multi_fr = data['bg'], data['image'], data['seg'], data['multi_fr']
    bg, image, seg, multi_fr = Variable(bg.cuda()), Variable(image.cuda()), Variable(seg.cuda()), Variable(multi_fr.cuda())
    if args.script:
        netB = torch.jit.trace(netB, (image, bg, seg, multi_fr))
        netG = torch.jit.trace(netG, (image, bg, seg, multi_fr))
    else:
        netB(image, bg, seg, multi_fr)
        netG(image, bg, seg, multi_fr)
    break

print('Starting training')
for epoch in range(0, args.epoch):
    netG.train()
    netD.train()

    lG, lD, GenL, DisL_r, DisL_f, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
    t0 = time.time()

    for i, data in enumerate(train_loader):
        # Initiating
        bg, image, seg, multi_fr, seg_gt, back_rnd = data['bg'], data['image'], data['seg'], data['multi_fr'], data['seg-gt'], data['back-rnd']
        bg, image, seg, multi_fr, seg_gt, back_rnd = Variable(bg.cuda()), Variable(image.cuda()), Variable(seg.cuda()), Variable(multi_fr.cuda()), Variable(seg_gt.cuda()), Variable(back_rnd.cuda())
        mask0 = Variable(torch.ones(seg.shape).cuda())

        tr0 = time.time()

        # Pseudo-supervision: the frozen netB provides targets for netG.
        alpha_pred_sup, fg_pred_sup = netB(image, bg, seg, multi_fr)
        mask = (alpha_pred_sup > -0.98).type(torch.cuda.FloatTensor)
        mask1 = (seg_gt > 0.95).type(torch.cuda.FloatTensor)

        ## Train Generator
        alpha_pred, fg_pred = netG(image, bg, seg, multi_fr)
        if args.debug:
            torch.save(fg_pred, args.debug)

        ## Pseudo-supervised losses
        al_loss = l1_loss(alpha_pred_sup, alpha_pred, mask0) + 0.5 * g_loss(alpha_pred_sup, alpha_pred, mask0)
        fg_loss = l1_loss(fg_pred_sup, fg_pred, mask)

        # Compose into the same background
        comp_loss = c_loss(image, alpha_pred, fg_pred, bg, mask1)

        # Randomly permute the backgrounds within the minibatch
        perm = torch.LongTensor(np.random.permutation(bg.shape[0]))
        bg_sh = bg[perm, :, :, :]

        al_mask = (alpha_pred > 0.95).type(torch.cuda.FloatTensor)

        # Choose the target background for composition:
        #   back_rnd: a separate set of captured background videos
        #   bg_sh: randomly permuted captured backgrounds from the same minibatch
        if np.random.random_sample() > 0.5:
            bg_sh = back_rnd

        image_sh = compose_image_withshift(alpha_pred, image * al_mask + fg_pred * (1 - al_mask), bg_sh, seg)

        fake_response = netD(image_sh)

        loss_ganG = GAN_loss(fake_response, label_type=True)
        lossG = loss_ganG + wt * (0.05 * comp_loss + 0.05 * al_loss + 0.05 * fg_loss)

        optimizerG.zero_grad()
        lossG.backward()
        optimizerG.step()
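        # Balancing heuristic: netD is stepped only once every 5 generator
        # updates below and uses a 10x smaller learning rate (1e-5 vs 1e-4),
        # which keeps the discriminator from overpowering the generator.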
        ## Train Discriminator
        fake_response = netD(image_sh)
        real_response = netD(image)

        loss_ganD_fake = GAN_loss(fake_response, label_type=False)
        loss_ganD_real = GAN_loss(real_response, label_type=True)
        lossD = (loss_ganD_real + loss_ganD_fake) * 0.5

        # Update the discriminator once for every 5 generator updates
        if i % 5 == 0:
            optimizerD.zero_grad()
            lossD.backward()
            optimizerD.step()

        # Running sums for console logging
        lG += lossG.data
        lD += lossD.data
        GenL += loss_ganG.data
        DisL_r += loss_ganD_real.data
        DisL_f += loss_ganD_fake.data
        alL += al_loss.data
        fgL += fg_loss.data
        compL += comp_loss.data

        log_writer.add_scalar('Generator Loss', lossG.data, epoch * KK + i + 1)
        log_writer.add_scalar('Discriminator Loss', lossD.data, epoch * KK + i + 1)
        log_writer.add_scalar('Generator Loss: Fake', loss_ganG.data, epoch * KK + i + 1)
        log_writer.add_scalar('Discriminator Loss: Real', loss_ganD_real.data, epoch * KK + i + 1)
        log_writer.add_scalar('Discriminator Loss: Fake', loss_ganD_fake.data, epoch * KK + i + 1)
        log_writer.add_scalar('Generator Loss: Alpha', al_loss.data, epoch * KK + i + 1)
        log_writer.add_scalar('Generator Loss: Fg', fg_loss.data, epoch * KK + i + 1)
        log_writer.add_scalar('Generator Loss: Comp', comp_loss.data, epoch * KK + i + 1)

        t1 = time.time()
        elapse += t1 - t0
        elapse_run += t1 - tr0
        t0 = t1

        if i % step == (step - 1):
            print('[%d, %5d] Gen-loss: %.4f Disc-loss: %.4f Alpha-loss: %.4f Fg-loss: %.4f Comp-loss: %.4f Time-all: %.4f Time-fwbw: %.4f'
                  % (epoch + 1, i + 1, lG / step, lD / step, alL / step, fgL / step, compL / step, elapse / step, elapse_run / step))
            lG, lD, GenL, DisL_r, DisL_f, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

            write_tb_log(image, 'image', log_writer, i)
            write_tb_log(seg, 'seg', log_writer, i)
            write_tb_log(alpha_pred_sup, 'alpha-sup', log_writer, i)
            write_tb_log(alpha_pred, 'alpha_pred', log_writer, i)
            write_tb_log(fg_pred_sup * mask, 'fg-pred-sup', log_writer, i)
            write_tb_log(fg_pred * mask, 'fg_pred', log_writer, i)

            # Composition: map alpha from [-1, 1] to [0, 1], then composite over the original background
            alpha_pred = (alpha_pred + 1) / 2
            comp = fg_pred * alpha_pred + (1 - alpha_pred) * bg
            write_tb_log(comp, 'composite-same', log_writer, i)
            write_tb_log(image_sh, 'composite-diff', log_writer, i)
            del comp

        del mask, back_rnd, mask0, seg_gt, mask1, bg, alpha_pred, alpha_pred_sup, image, fg_pred_sup, fg_pred, seg, multi_fr, image_sh, bg_sh, fake_response, real_response, al_loss, fg_loss, comp_loss, lossG, lossD, loss_ganD_real, loss_ganD_fake, loss_ganG

    if epoch % 2 == 0:
        torch.save(netG.state_dict(), os.path.join(model_dir, 'netG_epoch_%d.pth' % epoch))
        torch.save(optimizerG.state_dict(), os.path.join(model_dir, 'optimG_epoch_%d.pth' % epoch))
        torch.save(netD.state_dict(), os.path.join(model_dir, 'netD_epoch_%d.pth' % epoch))
        torch.save(optimizerD.state_dict(), os.path.join(model_dir, 'optimD_epoch_%d.pth' % epoch))

        # Change the weight every 2 epochs to put more stress on the adversarial loss and less on pseudo-supervision
        wt = wt / 2
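# Example invocation (hypothetical run name and hyperparameters; adjust to your
# data and GPU budget). CUDA_VISIBLE_DEVICES must be set, since the script reads
# it, and Video_data_train.csv must list the training videos:
#   CUDA_VISIBLE_DEVICES=0 python train_real_fixed.py -n real-fixed-cam -bs 4 -res 512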