def __init__()

in torchbenchmark/models/Background_Matting/__init__.py [0:0]


    def __init__(self, test, device, jit=False, batch_size=None, extra_args=None):
        """Set up the Background Matting benchmark.

        Writes a processed copy of ``Video_data_train.csv`` (with the script
        directory substituted for ``{scriptdir}``) into the data directory.
        It then loads every training batch into ``self.train_data``, moving
        the batches to the GPU when ``self.device == 'cuda'``. Finally, it
        builds the frozen network ``netB``, the trainable ``netG`` and
        ``netD``, their losses and their Adam optimizers.

        Args:
            test: Forwarded unchanged to the base-class constructor.
            device: Forwarded to the base class; ``self.device`` is compared
                against ``'cuda'`` to decide GPU placement.
            jit: Forwarded to the base class.
            batch_size: Forwarded to the base class; ``self.batch_size`` is
                then used as the DataLoader batch size.
            extra_args: Forwarded to the base class. ``None`` (the default)
                is replaced by a fresh empty list, avoiding a mutable default
                shared across instances.
        """
        if extra_args is None:
            extra_args = []
        super().__init__(test=test, device=device, jit=jit, batch_size=batch_size, extra_args=extra_args)

        self.opt = Namespace(**{
            'n_blocks1': 7,
            'n_blocks2': 3,
            'batch_size': self.batch_size,
            'resolution': 512,
            'name': 'Real_fixed'
        })
        # Use one device check everywhere (the original mixed the raw
        # `device` argument and `self.device`).
        use_cuda = self.device == 'cuda'

        scriptdir = os.path.dirname(os.path.realpath(__file__))
        # The template CSV contains `{scriptdir}` placeholders; write a copy
        # with this script's directory substituted in.
        csv_file_path = _create_data_dir().joinpath("Video_data_train_processed.csv")
        root = str(Path(__file__).parent)
        with open(f"{root}/Video_data_train.csv", "r") as r, open(csv_file_path, "w") as w:
            w.write(r.read().format(scriptdir=scriptdir))

        data_config_train = {
            'reso': (self.opt.resolution, self.opt.resolution)}
        traindata = VideoData(csv_file=csv_file_path,
                              data_config=data_config_train, transform=None)
        train_loader = torch.utils.data.DataLoader(
            traindata, batch_size=self.opt.batch_size, shuffle=True, num_workers=0, collate_fn=_collate_filter_none)

        # Materialize every batch up front. Each batch is a mapping whose
        # values support `.cuda()` (presumably all tensors — confirm against
        # `_collate_filter_none`).
        self.train_data = []
        for data in train_loader:
            if use_cuda:
                # Tensor.cuda() returns a copy instead of moving in place, so
                # the result must be stored back into the batch.
                for key in data:
                    data[key] = data[key].cuda()
            self.train_data.append(data)

        # netB is used for inference only: eval mode with gradients disabled.
        netB = ResnetConditionHR(input_nc=(
            3, 3, 1, 4), output_nc=4, n_blocks1=self.opt.n_blocks1, n_blocks2=self.opt.n_blocks2)
        if use_cuda:
            netB.cuda()  # Module.cuda() moves parameters in place.
        netB.eval()
        for param in netB.parameters():  # freeze netB
            param.requires_grad = False
        self.netB = netB

        netG = ResnetConditionHR(input_nc=(
            3, 3, 1, 4), output_nc=4, n_blocks1=self.opt.n_blocks1, n_blocks2=self.opt.n_blocks2)
        netG.apply(conv_init)
        self.netG = netG

        if use_cuda:
            self.netG.cuda()
            # TODO(asuhan): is this needed?
            torch.backends.cudnn.benchmark = True

        netD = MultiscaleDiscriminator(
            input_nc=3, num_D=1, norm_layer=nn.InstanceNorm2d, ndf=64)
        netD.apply(conv_init)
        self.netD = netD
        if use_cuda:
            self.netD.cuda()

        self.l1_loss = alpha_loss()
        self.c_loss = compose_loss()
        self.g_loss = alpha_gradient_loss()
        self.GAN_loss = GANloss()

        # Built after the .cuda() calls; because Module.cuda() works in place,
        # the optimizers hold the (possibly GPU-resident) parameters.
        self.optimizerG = optim.Adam(netG.parameters(), lr=1e-4)
        self.optimizerD = optim.Adam(netD.parameters(), lr=1e-5)

        self.log_writer = SummaryWriter(scriptdir)
        self.model_dir = scriptdir

        self._maybe_trace()