torchbenchmark/models/Background_Matting/__init__.py:

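"""Benchmark wrapper for Background Matting (https://arxiv.org/pdf/2004.00626.pdf).

Adapts the reference training loop into the torchbenchmark ``Model``
interface: a frozen network (netB) produces pseudo-supervision targets,
a generator (netG) predicts an alpha matte and foreground, and a
discriminator (netD) scores composites onto swapped backgrounds.
"""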
import os
import random
import time
from argparse import Namespace
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from tensorboardX import SummaryWriter
from torch.autograd import Variable

from torchbenchmark.tasks import COMPUTER_VISION

from .data_loader import VideoData
from .functions import compose_image_withshift, write_tb_log
from .loss_functions import GANloss, alpha_gradient_loss, alpha_loss, compose_loss
from .networks import MultiscaleDiscriminator, ResnetConditionHR, conv_init
from ...util.model import BenchmarkModel

torch.backends.cudnn.deterministic = False
torch.backends.cudnn.benchmark = True


def _collate_filter_none(batch):
    # Drop samples the dataset failed to load before collating.
    batch = list(filter(lambda x: x is not None, batch))
    return torch.utils.data.dataloader.default_collate(batch)


def _create_data_dir():
    data_dir = Path(__file__).parent.joinpath(".data")
    data_dir.mkdir(parents=True, exist_ok=True)
    return data_dir


class Model(BenchmarkModel):
    task = COMPUTER_VISION.OTHER_COMPUTER_VISION
    # Original batch size: 4
    # Original hardware: unknown
    # Source: https://arxiv.org/pdf/2004.00626.pdf
    DEFAULT_TRAIN_BSIZE = 4
    DEFAULT_EVAL_BSIZE = 1
    ALLOW_CUSTOMIZE_BSIZE = False

    def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
        super().__init__(test=test, device=device, jit=jit,
                         batch_size=batch_size, extra_args=extra_args)
        self.opt = Namespace(**{
            'n_blocks1': 7,
            'n_blocks2': 3,
            'batch_size': self.batch_size,
            'resolution': 512,
            'name': 'Real_fixed'
        })

        # Materialize the training CSV with absolute paths for this checkout.
        scriptdir = os.path.dirname(os.path.realpath(__file__))
        csv_file_path = _create_data_dir().joinpath("Video_data_train_processed.csv")
        root = str(Path(__file__).parent)
        with open(f"{root}/Video_data_train.csv", "r") as r, open(csv_file_path, "w") as w:
            w.write(r.read().format(scriptdir=scriptdir))

        data_config_train = {'reso': (self.opt.resolution, self.opt.resolution)}
        traindata = VideoData(csv_file=csv_file_path,
                              data_config=data_config_train, transform=None)
        train_loader = torch.utils.data.DataLoader(
            traindata, batch_size=self.opt.batch_size, shuffle=True,
            num_workers=0, collate_fn=_collate_filter_none)
        self.train_data = []
        for data in train_loader:
            self.train_data.append(data)
            if device == 'cuda':
                for key in data:
                    # Fix: Tensor.cuda() is not in-place; keep the moved copy.
                    data[key] = data[key].cuda()

        netB = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4,
                                 n_blocks1=self.opt.n_blocks1,
                                 n_blocks2=self.opt.n_blocks2)
        if self.device == 'cuda':
            netB.cuda()
        netB.eval()
        for param in netB.parameters():  # freeze netB
            param.requires_grad = False
        self.netB = netB

        netG = ResnetConditionHR(input_nc=(3, 3, 1, 4), output_nc=4,
                                 n_blocks1=self.opt.n_blocks1,
                                 n_blocks2=self.opt.n_blocks2)
        netG.apply(conv_init)
        self.netG = netG
        if self.device == 'cuda':
            self.netG.cuda()
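        # Network roles: netB is a frozen copy of the matting architecture
        # used only to generate pseudo-supervision targets; netG is the
        # generator actually being trained; netD (created below) is the
        # discriminator that scores composited images.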
        # TODO(asuhan): is this needed?
        torch.backends.cudnn.benchmark = True

        netD = MultiscaleDiscriminator(
            input_nc=3, num_D=1, norm_layer=nn.InstanceNorm2d, ndf=64)
        netD.apply(conv_init)
        # netD = nn.DataParallel(netD)
        self.netD = netD
        if self.device == 'cuda':
            self.netD.cuda()

        self.l1_loss = alpha_loss()
        self.c_loss = compose_loss()
        self.g_loss = alpha_gradient_loss()
        self.GAN_loss = GANloss()
        self.optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
        self.optimizerD = optim.Adam(netD.parameters(), lr=1e-5)
        self.log_writer = SummaryWriter(scriptdir)
        self.model_dir = scriptdir
        self._maybe_trace()

    def _maybe_trace(self):
        # Run (or JIT-trace) one forward pass of netB and netG on a single batch.
        for data in self.train_data:
            bg, image, seg, multi_fr = data['bg'], data['image'], data['seg'], data['multi_fr']
            if self.device == 'cuda':
                bg, image, seg, multi_fr = Variable(bg.cuda()), Variable(
                    image.cuda()), Variable(seg.cuda()), Variable(multi_fr.cuda())
            else:
                bg, image, seg, multi_fr = Variable(bg), Variable(
                    image), Variable(seg), Variable(multi_fr)
            if self.jit:
                self.netB = torch.jit.trace(
                    self.netB, (image, bg, seg, multi_fr))
                self.netG = torch.jit.trace(
                    self.netG, (image, bg, seg, multi_fr))
            else:
                self.netB(image, bg, seg, multi_fr)
                self.netG(image, bg, seg, multi_fr)
            break

    def get_module(self):
        # use netG (generation) for the returned module
        for _i, data in enumerate(self.train_data):
            bg, image, seg, multi_fr = data['bg'], data['image'], data['seg'], data['multi_fr']
            return self.netG, (image.to(self.device), bg.to(self.device),
                               seg.to(self.device), multi_fr.to(self.device))

    # eval() isn't implemented
    # train() is on by default
    def _set_mode(self, train):
        pass

    def train(self, niter=1):
        self.netG.train()
        self.netD.train()
        lG, lD, GenL, DisL_r, DisL_f, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        t0 = time.time()
        KK = len(self.train_data)
        wt = 1
        epoch = 0
        step = 50

        for i, data in enumerate(self.train_data):
            if i > niter:
                break

            # Initiating
            bg, image, seg, multi_fr, seg_gt, back_rnd = data['bg'], data[
                'image'], data['seg'], data['multi_fr'], data['seg-gt'], data['back-rnd']
            if self.device == 'cuda':
                bg, image, seg, multi_fr, seg_gt, back_rnd = Variable(bg.cuda()), Variable(image.cuda()), Variable(
                    seg.cuda()), Variable(multi_fr.cuda()), Variable(seg_gt.cuda()), Variable(back_rnd.cuda())
                mask0 = Variable(torch.ones(seg.shape).cuda())
            else:
                bg, image, seg, multi_fr, seg_gt, back_rnd = Variable(bg), Variable(
                    image), Variable(seg), Variable(multi_fr), Variable(seg_gt), Variable(back_rnd)
                mask0 = Variable(torch.ones(seg.shape))
            tr0 = time.time()

            # pseudo-supervision targets from the frozen netB
            alpha_pred_sup, fg_pred_sup = self.netB(image, bg, seg, multi_fr)
            if self.device == 'cuda':
                mask = (alpha_pred_sup > -0.98).type(torch.cuda.FloatTensor)
                mask1 = (seg_gt > 0.95).type(torch.cuda.FloatTensor)
            else:
                mask = (alpha_pred_sup > -0.98).type(torch.FloatTensor)
                mask1 = (seg_gt > 0.95).type(torch.FloatTensor)

            # Train Generator
            alpha_pred, fg_pred = self.netG(image, bg, seg, multi_fr)

            # pseudo-supervised losses
            al_loss = self.l1_loss(alpha_pred_sup, alpha_pred, mask0) + \
                0.5 * self.g_loss(alpha_pred_sup, alpha_pred, mask0)
            fg_loss = self.l1_loss(fg_pred_sup, fg_pred, mask)

            # compose into the same background
            comp_loss = self.c_loss(image, alpha_pred, fg_pred, bg, mask1)

            # randomly permute the backgrounds within the minibatch
            perm = torch.LongTensor(np.random.permutation(bg.shape[0]))
            bg_sh = bg[perm, :, :, :]

            if self.device == 'cuda':
                al_mask = (alpha_pred > 0.95).type(torch.cuda.FloatTensor)
            else:
                al_mask = (alpha_pred > 0.95).type(torch.FloatTensor)
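            # What follows builds the adversarial "fake": the predicted
            # foreground is composited onto a background the frame was NOT
            # captured against, and the result is scored by netD. Judging by
            # its name, compose_image_withshift also shifts the background so
            # the discriminator cannot rely on exact alignment (a hedged
            # reading; see functions.py for the actual implementation).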
            # Choose the target background for composition:
            # - back_rnd: a separate set of captured background video frames
            # - bg_sh: randomly permuted backgrounds from the same minibatch
            if np.random.random_sample() > 0.5:
                bg_sh = back_rnd
            image_sh = compose_image_withshift(
                alpha_pred, image * al_mask + fg_pred * (1 - al_mask), bg_sh, seg)

            fake_response = self.netD(image_sh)
            loss_ganG = self.GAN_loss(fake_response, label_type=True)
            lossG = loss_ganG + wt * (0.05 * comp_loss + 0.05 * al_loss + 0.05 * fg_loss)
            self.optimizerG.zero_grad()
            lossG.backward()
            self.optimizerG.step()

            # Train Discriminator
            fake_response = self.netD(image_sh)
            real_response = self.netD(image)
            loss_ganD_fake = self.GAN_loss(fake_response, label_type=False)
            loss_ganD_real = self.GAN_loss(real_response, label_type=True)
            lossD = (loss_ganD_real + loss_ganD_fake) * 0.5
            # Update the discriminator once for every 5 generator updates
            if i % 5 == 0:
                self.optimizerD.zero_grad()
                lossD.backward()
                self.optimizerD.step()

            lG += lossG.data
            lD += lossD.data
            GenL += loss_ganG.data
            DisL_r += loss_ganD_real.data
            DisL_f += loss_ganD_fake.data
            alL += al_loss.data
            fgL += fg_loss.data
            compL += comp_loss.data

            self.log_writer.add_scalar('Generator Loss', lossG.data, epoch * KK + i + 1)
            self.log_writer.add_scalar('Discriminator Loss', lossD.data, epoch * KK + i + 1)
            self.log_writer.add_scalar('Generator Loss: Fake', loss_ganG.data, epoch * KK + i + 1)
            self.log_writer.add_scalar('Discriminator Loss: Real', loss_ganD_real.data, epoch * KK + i + 1)
            self.log_writer.add_scalar('Discriminator Loss: Fake', loss_ganD_fake.data, epoch * KK + i + 1)
            self.log_writer.add_scalar('Generator Loss: Alpha', al_loss.data, epoch * KK + i + 1)
            self.log_writer.add_scalar('Generator Loss: Fg', fg_loss.data, epoch * KK + i + 1)
            self.log_writer.add_scalar('Generator Loss: Comp', comp_loss.data, epoch * KK + i + 1)

            t1 = time.time()
            elapse += t1 - t0
            elapse_run += t1 - tr0
            t0 = t1

            if i % step == (step - 1):
                print('[%d, %5d] Gen-loss: %.4f Disc-loss: %.4f Alpha-loss: %.4f '
                      'Fg-loss: %.4f Comp-loss: %.4f Time-all: %.4f Time-fwbw: %.4f' %
                      (epoch + 1, i + 1, lG / step, lD / step, alL / step, fgL / step,
                       compL / step, elapse / step, elapse_run / step))
                lG, lD, GenL, DisL_r, DisL_f, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

                write_tb_log(image, 'image', self.log_writer, i)
                write_tb_log(seg, 'seg', self.log_writer, i)
                write_tb_log(alpha_pred_sup, 'alpha-sup', self.log_writer, i)
                write_tb_log(alpha_pred, 'alpha_pred', self.log_writer, i)
                write_tb_log(fg_pred_sup * mask, 'fg-pred-sup', self.log_writer, i)
                write_tb_log(fg_pred * mask, 'fg_pred', self.log_writer, i)

                # composition
                alpha_pred = (alpha_pred + 1) / 2
                comp = fg_pred * alpha_pred + (1 - alpha_pred) * bg
                write_tb_log(comp, 'composite-same', self.log_writer, i)
                write_tb_log(image_sh, 'composite-diff', self.log_writer, i)
                del comp

            del mask, back_rnd, mask0, seg_gt, mask1, bg, alpha_pred, alpha_pred_sup
            del image, fg_pred_sup, fg_pred, seg, multi_fr, image_sh, bg_sh
            del fake_response, real_response, al_loss, fg_loss, comp_loss
            del lossG, lossD, loss_ganD_real, loss_ganD_fake, loss_ganG

        if epoch % 2 == 0:
            torch.save(self.netG.state_dict(), os.path.join(
                self.model_dir, 'netG_epoch_%d.pth' % epoch))
            torch.save(self.optimizerG.state_dict(), os.path.join(
                self.model_dir, 'optimG_epoch_%d.pth' % epoch))
            torch.save(self.netD.state_dict(), os.path.join(
                self.model_dir, 'netD_epoch_%d.pth' % epoch))
            torch.save(self.optimizerD.state_dict(), os.path.join(
                self.model_dir, 'optimD_epoch_%d.pth' % epoch))

            # Halve the pseudo-supervision weight every 2 epochs to put more
            # stress on the discriminator and less on the pseudo-supervision.
            wt = wt / 2
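    # Loss-weighting arithmetic, for orientation: with wt = 1 each of the
    # pseudo-supervised terms (comp, alpha, fg) enters lossG with weight
    # 0.05 while the GAN term keeps unit weight; after each halving the
    # pseudo-supervised weights drop to 0.025, 0.0125, ... This benchmark
    # fixes epoch = 0 and runs only a few iterations, so in practice only
    # the initial wt = 1 weighting is exercised.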
    def eval(self, niter=1):
        raise NotImplementedError()
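
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only; the torchbenchmark harness
# normally constructs Model itself). Assumes the benchmark's input data has
# been installed and a CUDA device is available; pass device="cpu" otherwise.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    m = Model(test="train", device="cuda")
    m.train(niter=1)  # runs two iterations; netD updates on iteration 0
    netG, example_inputs = m.get_module()
    alpha_pred, fg_pred = netG(*example_inputs)
    print(alpha_pred.shape, fg_pred.shape)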