in torchbenchmark/models/Background_Matting/__init__.py [0:0]
def __init__(self, test, device, jit=False, batch_size=None, extra_args=[]):
    """Build the training data cache, networks, losses and optimizers.

    Args:
        test: Forwarded unchanged to the base-class initializer.
        device: Target device string; ``'cuda'`` triggers moving the cached
            batches to the GPU.
        jit: Forwarded unchanged to the base-class initializer.
        batch_size: Forwarded unchanged to the base-class initializer.
        extra_args: Forwarded unchanged to the base-class initializer.
            NOTE(review): the mutable default is kept for interface
            compatibility; it is only forwarded, never mutated here.
    """
    super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)

    # self.batch_size is read after the base initializer has run
    # (presumably set there -- confirm against the base class).
    self.opt = Namespace(
        n_blocks1=7,
        n_blocks2=3,
        batch_size=self.batch_size,
        resolution=512,
        name='Real_fixed',
    )

    # Materialize the training CSV: the template shipped next to this file
    # contains a ``{scriptdir}`` placeholder that is filled in here.
    scriptdir = os.path.dirname(os.path.realpath(__file__))
    csv_file_path = _create_data_dir().joinpath("Video_data_train_processed.csv")
    root = str(Path(__file__).parent)
    with open(f"{root}/Video_data_train.csv", "r") as r, open(csv_file_path, "w") as w:
        w.write(r.read().format(scriptdir=scriptdir))

    data_config_train = {'reso': (self.opt.resolution, self.opt.resolution)}
    traindata = VideoData(csv_file=csv_file_path,
                          data_config=data_config_train, transform=None)
    train_loader = torch.utils.data.DataLoader(
        traindata, batch_size=self.opt.batch_size, shuffle=True,
        num_workers=0, collate_fn=_collate_filter_none)

    # Cache every batch up front so later iterations do not touch the loader.
    self.train_data = []
    for data in train_loader:
        if device == 'cuda':
            for key in data:
                # Tensor.cuda() returns a copy rather than moving in place,
                # so the result must be stored back into the batch.
                data[key] = data[key].cuda()
        self.train_data.append(data)

    # netB: kept in eval mode with gradients disabled.
    netB = ResnetConditionHR(
        input_nc=(3, 3, 1, 4), output_nc=4,
        n_blocks1=self.opt.n_blocks1, n_blocks2=self.opt.n_blocks2)
    if self.device == 'cuda':
        netB.cuda()
    netB.eval()
    for param in netB.parameters():  # freeze netB
        param.requires_grad = False
    self.netB = netB

    # netG: same architecture as netB, re-initialized with conv_init.
    netG = ResnetConditionHR(
        input_nc=(3, 3, 1, 4), output_nc=4,
        n_blocks1=self.opt.n_blocks1, n_blocks2=self.opt.n_blocks2)
    netG.apply(conv_init)
    self.netG = netG
    if self.device == 'cuda':
        self.netG.cuda()
        # TODO(asuhan): is this needed?
        # NOTE: this is a process-wide setting, not scoped to this model.
        torch.backends.cudnn.benchmark = True

    # netD: discriminator, re-initialized with conv_init.
    netD = MultiscaleDiscriminator(
        input_nc=3, num_D=1, norm_layer=nn.InstanceNorm2d, ndf=64)
    netD.apply(conv_init)
    # netD = nn.DataParallel(netD)
    self.netD = netD
    if self.device == 'cuda':
        self.netD.cuda()

    # Loss objects.
    self.l1_loss = alpha_loss()
    self.c_loss = compose_loss()
    self.g_loss = alpha_gradient_loss()
    self.GAN_loss = GANloss()

    # Optimizers are created after the modules were moved to their device.
    self.optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
    self.optimizerD = optim.Adam(netD.parameters(), lr=1e-5)

    # The summary writer and model_dir both point at this file's directory.
    self.log_writer = SummaryWriter(scriptdir)
    self.model_dir = scriptdir

    self._maybe_trace()