torchbenchmark/models/Background_Matting/train_adobe.py (104 lines of code) (raw):
from __future__ import print_function
import torch
from torch.autograd import Variable
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
import os
import time
import argparse
from data_loader import AdobeDataAffineHR
from functions import *
from networks import ResnetConditionHR, conv_init
from loss_functions import alpha_loss, compose_loss, alpha_gradient_loss
# CUDA
# os.environ["CUDA_VISIBLE_DEVICES"]="4"
# Report which devices were selected via CUDA_VISIBLE_DEVICES. The variable may
# legitimately be unset, so look it up with .get() instead of crashing with
# KeyError before training even starts.
print('CUDA Device: ' + os.environ.get("CUDA_VISIBLE_DEVICES", '<unset>'))
# Parse command-line arguments.
# --name, --batch_size and --reso have no sensible default: without them the
# script later fails with an opaque TypeError (e.g. 'TB_Summary/' + None), so
# argparse is told to require them and report a clear usage error instead.
parser = argparse.ArgumentParser(description='Training Background Matting on Adobe Dataset.')
parser.add_argument('-n', '--name', type=str, required=True, help='Name of tensorboard and model saving folders.')
parser.add_argument('-bs', '--batch_size', type=int, required=True, help='Batch Size.')
parser.add_argument('-res', '--reso', type=int, required=True, help='Input image resolution')
parser.add_argument('-epoch', '--epoch', type=int, default=60, help='Maximum Epoch')
parser.add_argument('-n_blocks1', '--n_blocks1', type=int, default=7, help='Number of residual blocks after Context Switching.')
parser.add_argument('-n_blocks2', '--n_blocks2', type=int, default=3, help='Number of residual blocks for Fg and alpha each.')
args = parser.parse_args()
## Directories
# Per-run output locations, both keyed by --name.
tb_dir = 'TB_Summary/' + args.name
# Trailing separator so that plain string concatenation (model_dir + filename)
# resolves to a file *inside* this directory rather than a sibling path like
# 'Models/<name>net_epoch_...'.
model_dir = 'Models/' + args.name + '/'
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(model_dir, exist_ok=True)
os.makedirs(tb_dir, exist_ok=True)
## Input list
# Dataset options: a square resolution of --reso x --reso; the remaining keys
# ('trimapK', 'noise') are interpreted by AdobeDataAffineHR.
data_config_train = dict(reso=[args.reso] * 2, trimapK=[5, 5], noise=True)

# DATA LOADING
print('\n[Phase 1] : Data Preparation')
def collate_filter_none(batch):
    """Collate a list of samples after discarding the ``None`` entries.

    A dataset can ask for a sample to be skipped by returning ``None``; the
    surviving samples are merged with PyTorch's ``default_collate``.

    Args:
        batch: list of samples produced by the dataset (``None`` allowed).

    Returns:
        The ``default_collate`` result for the non-``None`` samples.
    """
    kept = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept)
# Original data
# Training samples are listed in a CSV file that AdobeDataAffineHR reads.
traindata = AdobeDataAffineHR(
    csv_file='Data_adobe/Adobe_train_data.csv',
    data_config=data_config_train,
    transform=None,
)
# Shuffled loader; the worker count is tied to the batch size, and the custom
# collate function drops samples the dataset returned as None.
train_loader = torch.utils.data.DataLoader(
    traindata,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.batch_size,
    collate_fn=collate_filter_none,
)
print('\n[Phase 2] : Initialization')
# Build the matting network. The residual-block counts come from
# --n_blocks1 / --n_blocks2 (previously parsed but silently ignored in favour
# of hard-coded 7 and 3, which remain the argparse defaults).
# NOTE(review): input_nc presumably lists channel counts for the four inputs
# in call order (image, bg_tr, seg, multi_fr) — confirm in ResnetConditionHR.
net = ResnetConditionHR(
    input_nc=(3, 3, 1, 4),
    output_nc=4,
    n_blocks1=args.n_blocks1,
    n_blocks2=args.n_blocks2,
    norm_layer=nn.BatchNorm2d,
)
net.apply(conv_init)
net = nn.DataParallel(net)
# net.load_state_dict(torch.load(model_dir + 'net_epoch_X'))  # uncomment this if you are initializing your model
net.cuda()
# Inputs have a fixed resolution (--reso), so let cuDNN autotune conv kernels.
torch.backends.cudnn.benchmark = True
# Loss modules; in the training loop each one is called with a mask as its
# last argument.
l1_loss, c_loss, g_loss = alpha_loss(), compose_loss(), alpha_gradient_loss()

# Adam over all network parameters.
optimizer = optim.Adam(net.parameters(), lr=1e-4)
# optimizer.load_state_dict(torch.load(model_dir + 'optim_epoch_X'))  # uncomment this if you are initializing your model

# TensorBoard writer for this run.
log_writer = SummaryWriter(tb_dir)
print('Starting Training')
step = 50  # steps to visualize training images in tensorboard
KK = len(train_loader)  # iterations per epoch; used to build a global step
for epoch in range(0, args.epoch):
    net.train()

    # Running sums over the current `step`-iteration reporting window.
    netL, alL, fgL, fg_cL, al_fg_cL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0

    t0 = time.time()
    # Whole-epoch loss sum and batch count; their mean goes into checkpoint names.
    testL = 0
    ct_tst = 0
    for i, data in enumerate(train_loader):
        # Initiating: move every tensor of the batch to the GPU.
        # (Variable wrappers are no-ops since PyTorch 0.4 and are not needed.)
        fg, bg, alpha, image, seg, bg_tr, multi_fr = (
            data[key].cuda()
            for key in ('fg', 'bg', 'alpha', 'image', 'seg', 'bg_tr', 'multi_fr')
        )

        # Foreground loss only where alpha > -0.99; the other terms use an
        # all-ones mask (every pixel).
        mask = (alpha > -0.99).type(torch.cuda.FloatTensor)
        mask0 = torch.ones(alpha.shape).cuda()

        tr0 = time.time()
        alpha_pred, fg_pred = net(image, bg_tr, seg, multi_fr)

        ## Put needed loss here
        al_loss = l1_loss(alpha, alpha_pred, mask0)
        fg_loss = l1_loss(fg, fg_pred, mask)

        # For the compositing loss, take the foreground from the input image
        # wherever predicted alpha exceeds 0.95, and from fg_pred elsewhere.
        al_mask = (alpha_pred > 0.95).type(torch.cuda.FloatTensor)
        fg_pred_c = image * al_mask + fg_pred * (1 - al_mask)
        fg_c_loss = c_loss(image, alpha_pred, fg_pred_c, bg, mask0)

        al_fg_c_loss = g_loss(alpha, alpha_pred, mask0)

        loss = al_loss + 2 * fg_loss + fg_c_loss + al_fg_c_loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Python floats of each term, read once and reused for the running
        # sums and TensorBoard (add_scalar synchronizes with the GPU anyway).
        loss_v = loss.item()
        al_v = al_loss.item()
        fg_v = fg_loss.item()
        fg_c_v = fg_c_loss.item()
        al_fg_c_v = al_fg_c_loss.item()

        netL += loss_v
        alL += al_v
        fgL += fg_v
        fg_cL += fg_c_v
        al_fg_cL += al_fg_c_v

        global_step = epoch * KK + i + 1
        log_writer.add_scalar('training_loss', loss_v, global_step)
        log_writer.add_scalar('alpha_loss', al_v, global_step)
        log_writer.add_scalar('fg_loss', fg_v, global_step)
        log_writer.add_scalar('comp_loss', fg_c_v, global_step)
        log_writer.add_scalar('alpha_gradient_loss', al_fg_c_v, global_step)

        t1 = time.time()
        elapse += t1 - t0  # wall time since the previous iteration ended (includes data loading)
        elapse_run += t1 - tr0  # time since the forward pass started
        t0 = t1

        testL += loss_v
        ct_tst += 1

        if i % step == (step - 1):
            print('[%d, %5d] Total-loss: %.4f Alpha-loss: %.4f Fg-loss: %.4f Comp-loss: %.4f Alpha-gradient-loss: %.4f Time-all: %.4f Time-fwbw: %.4f' % (epoch + 1, i + 1, netL / step, alL / step, fgL / step, fg_cL / step, al_fg_cL / step, elapse / step, elapse_run / step))
            netL, alL, fgL, fg_cL, al_fg_cL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0

            # NOTE(review): image summaries receive the within-epoch index `i`
            # (unlike the scalars above, which use global_step); confirm whether
            # write_tb_log should get global_step so epochs don't collide.
            write_tb_log(image, 'image', log_writer, i)
            write_tb_log(seg, 'seg', log_writer, i)
            write_tb_log(alpha, 'alpha', log_writer, i)
            write_tb_log(alpha_pred, 'alpha_pred', log_writer, i)
            write_tb_log(fg * mask, 'fg', log_writer, i)
            write_tb_log(fg_pred * mask, 'fg_pred', log_writer, i)
            write_tb_log(multi_fr[0:4, 0, ...].unsqueeze(1), 'multi_fr', log_writer, i)

            # composition: (x + 1) / 2 rescales [-1, 1] to [0, 1]
            # (assumes the network predicts alpha in [-1, 1] — confirm).
            alpha_pred = (alpha_pred + 1) / 2
            comp = fg_pred * alpha_pred + (1 - alpha_pred) * bg
            write_tb_log(comp, 'composite', log_writer, i)
            del comp

        # Drop references to this batch's tensors before loading the next one.
        del fg, bg, alpha, image, alpha_pred, fg_pred, seg, multi_fr

    # Saving: file names carry the epoch and its mean training loss.
    # os.path.join keeps the files inside model_dir even when it has no
    # trailing separator; max(..., 1) avoids ZeroDivisionError on an empty loader.
    mean_loss = testL / max(ct_tst, 1)
    torch.save(net.state_dict(), os.path.join(model_dir, 'net_epoch_%d_%.4f.pth' % (epoch, mean_loss)))
    torch.save(optimizer.state_dict(), os.path.join(model_dir, 'optim_epoch_%d_%.4f.pth' % (epoch, mean_loss)))