def train()

in torchbenchmark/models/Background_Matting/__init__.py [0:0]
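Runs niter + 1 training iterations of the Background-Matting GAN. Each iteration, a teacher network (self.netB) produces pseudo-labels for alpha and foreground, the generator (self.netG) is optimized against those pseudo-labels plus an adversarial loss from the discriminator (self.netD) on a composited image, and the discriminator itself is stepped once every 5 iterations. The generator objective is loss_ganG + wt * 0.05 * (comp_loss + al_loss + fg_loss); losses and sample images are logged to TensorBoard every 50 batches (the step interval), and checkpoints are saved after the loop.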


    def train(self, niter=1):
        self.netG.train()
        self.netD.train()
        # Running sums for the logging window, reset every `step` batches.
        lG, lD, GenL, DisL_r, DisL_f, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0
        t0 = time.time()
        KK = len(self.train_data)  # batches per epoch, used as the TensorBoard step base
        wt = 1     # weight on the pseudo-supervised loss terms
        epoch = 0  # this benchmark runs a single (partial) epoch
        step = 50  # logging interval, in batches

        for i, data in enumerate(self.train_data):
            if i > niter:  # note: runs niter + 1 batches before breaking
                break
            # Unpack the batch.
            bg, image, seg, multi_fr, seg_gt, back_rnd = (
                data['bg'], data['image'], data['seg'],
                data['multi_fr'], data['seg-gt'], data['back-rnd'])

            # Move the batch to the GPU if requested. (torch.autograd.Variable is a
            # no-op wrapper since PyTorch 0.4, so plain tensors are used here.)
            if self.device == 'cuda':
                bg, image, seg, multi_fr, seg_gt, back_rnd = (
                    t.cuda() for t in (bg, image, seg, multi_fr, seg_gt, back_rnd))
            mask0 = torch.ones_like(seg)

            tr0 = time.time()  # start of the forward/backward ("fwbw") timer

            # Pseudo-supervision: the teacher netB provides alpha/foreground targets.
            alpha_pred_sup, fg_pred_sup = self.netB(image, bg, seg, multi_fr)
            # Alpha is in [-1, 1] (-1 = pure background): `mask` keeps pixels the
            # teacher marks as non-background, `mask1` keeps confidently foreground
            # pixels of the ground-truth segmentation.
            mask = (alpha_pred_sup > -0.98).float()
            mask1 = (seg_gt > 0.95).float()

            # Train Generator

            alpha_pred, fg_pred = self.netG(image, bg, seg, multi_fr)

            # Pseudo-supervised losses against the teacher's predictions.
            al_loss = self.l1_loss(alpha_pred_sup, alpha_pred, mask0) + \
                0.5 * self.g_loss(alpha_pred_sup, alpha_pred, mask0)
            fg_loss = self.l1_loss(fg_pred_sup, fg_pred, mask)

            # Compositing loss: recompose with the captured background and compare to the input image.
            comp_loss = self.c_loss(image, alpha_pred, fg_pred, bg, mask1)

            # Randomly permute the captured backgrounds within the minibatch.
            perm = torch.randperm(bg.shape[0])
            bg_sh = bg[perm]

            al_mask = (alpha_pred > 0.95).float()

            # Choose the target background for the composite:
            #   back_rnd: backgrounds from a separately captured set of background videos
            #   bg_sh:    captured backgrounds randomly permuted within the minibatch
            if np.random.random_sample() > 0.5:
                bg_sh = back_rnd

            # Composite the predicted foreground onto the chosen background (with a shift).
            image_sh = compose_image_withshift(
                alpha_pred, image*al_mask + fg_pred*(1-al_mask), bg_sh, seg)

            fake_response = self.netD(image_sh)

            loss_ganG = self.GAN_loss(fake_response, label_type=True)

            # Total generator objective: adversarial term plus weighted pseudo-supervision.
            lossG = loss_ganG + wt*(0.05*comp_loss + 0.05*al_loss + 0.05*fg_loss)

            self.optimizerG.zero_grad()

            lossG.backward()
            self.optimizerG.step()

            # Train Discriminator

            # Detach the composite so the discriminator pass does not backprop into netG.
            fake_response = self.netD(image_sh.detach())
            real_response = self.netD(image)

            loss_ganD_fake = self.GAN_loss(fake_response, label_type=False)
            loss_ganD_real = self.GAN_loss(real_response, label_type=True)

            lossD = (loss_ganD_real+loss_ganD_fake)*0.5

            # Step the discriminator once every 5 generator updates.
            if i % 5 == 0:
                self.optimizerD.zero_grad()
                lossD.backward()
                self.optimizerD.step()

            # Accumulate running losses for the logging window.
            lG += lossG.data
            lD += lossD.data
            GenL += loss_ganG.data
            DisL_r += loss_ganD_real.data
            DisL_f += loss_ganD_fake.data

            alL += al_loss.data
            fgL += fg_loss.data
            compL += comp_loss.data

            self.log_writer.add_scalar(
                'Generator Loss', lossG.data, epoch*KK + i + 1)
            self.log_writer.add_scalar('Discriminator Loss',
                                       lossD.data, epoch*KK + i + 1)
            self.log_writer.add_scalar('Generator Loss: Fake',
                                       loss_ganG.data, epoch*KK + i + 1)
            self.log_writer.add_scalar('Discriminator Loss: Real',
                                       loss_ganD_real.data, epoch*KK + i + 1)
            self.log_writer.add_scalar('Discriminator Loss: Fake',
                                       loss_ganD_fake.data, epoch*KK + i + 1)

            self.log_writer.add_scalar('Generator Loss: Alpha',
                                       al_loss.data, epoch*KK + i + 1)
            self.log_writer.add_scalar('Generator Loss: Fg',
                                       fg_loss.data, epoch*KK + i + 1)
            self.log_writer.add_scalar('Generator Loss: Comp',
                                       comp_loss.data, epoch*KK + i + 1)

            t1 = time.time()

            elapse += t1 - t0        # full iteration time, including data loading
            elapse_run += t1 - tr0   # forward/backward time only
            t0 = t1

            if i % step == (step-1):
                print('[%d, %5d] Gen-loss:  %.4f Disc-loss: %.4f Alpha-loss: %.4f Fg-loss: %.4f Comp-loss: %.4f Time-all: %.4f Time-fwbw: %.4f' %
                      (epoch + 1, i + 1, lG/step, lD/step, alL/step, fgL/step, compL/step, elapse/step, elapse_run/step))
                lG, lD, GenL, DisL_r, DisL_f, alL, fgL, compL, elapse_run, elapse = 0, 0, 0, 0, 0, 0, 0, 0, 0, 0

                write_tb_log(image, 'image', self.log_writer, i)
                write_tb_log(seg, 'seg', self.log_writer, i)
                write_tb_log(alpha_pred_sup, 'alpha-sup', self.log_writer, i)
                write_tb_log(alpha_pred, 'alpha_pred', self.log_writer, i)
                write_tb_log(fg_pred_sup*mask, 'fg-pred-sup',
                             self.log_writer, i)
                write_tb_log(fg_pred*mask, 'fg_pred', self.log_writer, i)

                # Composition: map alpha from [-1, 1] to [0, 1], then blend fg over bg.
                alpha_pred = (alpha_pred+1)/2
                comp = fg_pred*alpha_pred + (1-alpha_pred)*bg
                write_tb_log(comp, 'composite-same', self.log_writer, i)
                write_tb_log(image_sh, 'composite-diff', self.log_writer, i)

                del comp

            # Free per-iteration tensors before the next batch to reduce peak memory.
            del mask, back_rnd, mask0, seg_gt, mask1, bg, alpha_pred, alpha_pred_sup, \
                image, fg_pred_sup, fg_pred, seg, multi_fr, image_sh, bg_sh, \
                fake_response, real_response, al_loss, fg_loss, comp_loss, \
                lossG, lossD, loss_ganD_real, loss_ganD_fake, loss_ganG

        # Checkpoint every other epoch (epoch is fixed at 0 in this benchmark excerpt).
        if epoch % 2 == 0:
            torch.save(self.netG.state_dict(),
                       os.path.join(self.model_dir, 'netG_epoch_%d.pth' % (epoch)))
            torch.save(self.optimizerG.state_dict(),
                       os.path.join(self.model_dir, 'optimG_epoch_%d.pth' % (epoch)))
            torch.save(self.netD.state_dict(),
                       os.path.join(self.model_dir, 'netD_epoch_%d.pth' % (epoch)))
            torch.save(self.optimizerD.state_dict(),
                       os.path.join(self.model_dir, 'optimD_epoch_%d.pth' % (epoch)))

            # Halve the pseudo-supervision weight every 2 epochs, shifting emphasis
            # toward the adversarial (discriminator) signal.
            wt = wt/2
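
The masked loss helpers (self.l1_loss, self.g_loss, self.c_loss) are attributes set up elsewhere in the class; a minimal sketch of the assumed masked-L1 semantics (the body below is an assumption, not the repo's implementation):

    import torch

    def masked_l1(pred, target, mask, eps=1e-6):
        # Hypothetical helper: mean absolute error over the pixels kept by `mask`.
        return torch.sum(torch.abs(pred - target) * mask) / (torch.sum(mask) + eps)

A minimal invocation sketch, assuming the usual torchbenchmark Model wrapper defined in this same __init__.py (the constructor arguments are assumptions; only train() is shown in this excerpt):

    from torchbenchmark.models.Background_Matting import Model

    model = Model(device='cuda')  # assumed signature; sets up nets, data, optimizers
    model.train(niter=1)          # runs niter + 1 batches (see the break condition above)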