def forward()

in model/transformer_gan.py [0:0]


    def forward(self, data, target, reset_mems, train_loss, mems=None, status_vec=None,
                update_D0=False):
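        """Compute the requested training losses for the transformer GAN.

        train_loss selects which losses are computed ("mle", "gen", "dis",
        "classifier", or combinations thereof); mems and status_vec carry the
        generator's recurrent memory and note-status state across calls.
        Returns a dict with "mle", "gen_loss", "dis_loss" and, when applicable,
        "gp_loss" and "mems".
        """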
        # train_loss can be "mle", "gen_loss", "dis_loss", "classifier", or a combination such as "mle_and_gen_loss"

        return_dict = {"mle": None, "gen_loss": None, "dis_loss": None, "gp_loss": None, "mems": None}

        if "mle" in train_loss:
            ret = self.generator(data, target, reset_mems, mems, status_vec=status_vec)
            return_dict["mle"], return_dict["mems"] = ret

        # Sample a sequence
        if "gen" in train_loss or "dis" in train_loss or "classifier" in train_loss:
            # TODO: low priority could potentially make forward_generate a static func?
            # Cache params
            cache_tgt_len, cache_mem_len = (
                self.generator.tgt_len, self.generator.mem_len
            )

            # Reset params for sampling
            self.generator.reset_length(
                1, self.cfg.DISCRIMINATOR.mem_len
            )  # Use mem_len=bert_len
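            # tgt_len=1 because tokens are generated one step at a time below;
            # the cached tgt_len/mem_len are restored via reset_length after sampling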

            # Generate samples
            # First token has to be expanded into one-hot if data is in index form

            # sample_mems can exceed the discriminator mem_len, but only for a single forward_generate pass

            def process_for_sequence(inp):
                if len(inp.shape) == 1:
                    # Expand token indices into a one-hot float tensor
                    # (equivalent to F.one_hot(inp, self.ntokens).float())
                    return (
                        inp.new_zeros(
                            (*inp.shape, self.ntokens), dtype=torch.float32
                        ).scatter_(-1, inp[..., None], 1)
                    )
                elif len(inp.shape) == 2:
                    return inp
                else:
                    raise NotImplementedError

            seq = []
            sample_mems = None
            # TODO: When training gen do not pass only context into dis (since no grads anyway)
            # TODO: do not loop over context

            status_vec = None
            with torch.no_grad():
                # Last token in context is used to feed forward for sequential generation
                if self.cfg.DISCRIMINATOR.context_len > 1:
                    context = data[:self.cfg.DISCRIMINATOR.context_len - 1]

                    if self.cfg.TRAIN.append_note_status:
                        bptt, batch_size = context.shape
                        status_vec = context.new_zeros((bptt, batch_size, self.vec_len), dtype=torch.bool)
                        self.vocab.update_status_vec(context, status_vec)
                    ret = self.generator.forward_generate(context, sample_mems,
                                                          status_vec=status_vec)
                    _, sample_mems = ret

            sample_len = self.cfg.DISCRIMINATOR.tgt_len // self.cfg.DISCRIMINATOR.sample_chunks_mem
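            # The discriminator target length is processed in sample_chunks_mem chunks;
            # sample_mems is detached and seq is reset at every chunk boundary, so
            # gradients and memory do not flow across chunks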

            # Do not automatically detach gradients w.r.t. memory; detach manually at chunk boundaries instead
            self.generator.detach_mems_grad = False
            gen_loss, dis_loss, gp_loss = 0, 0, 0

            # Split real and fake samples according to sample_chunks_mem
            for chunk_start in range(0, self.cfg.DISCRIMINATOR.tgt_len, sample_len):
                chunk_end = min(chunk_start + sample_len, self.cfg.DISCRIMINATOR.tgt_len)
                # TODO: Can we retain sub graph after calling backward?
                for ind in range(chunk_start, chunk_end):

                    if ind < self.cfg.DISCRIMINATOR.context_len:
                        seq.append(process_for_sequence(data[ind]))
                        continue
                    # Since start token is chosen and bert tgt len is fixed
                    elif self.cfg.DISCRIMINATOR.truncate_backprop or (ind == chunk_start) or "classifier" in train_loss:
                        # Note: freezing the discriminator does not save much memory, since
                        # requires_grad=False only applies to the discriminator params; freeing
                        # it does not save much either (assuming PyTorch already optimizes and
                        # shares computation graphs). It mainly saves some time in backward()

                        # Stop gradient propagation: hard argmax (Gumbel max), so gradients
                        # do not propagate through this step
                        inp = torch.argmax(seq[-1], dim=-1)[None, :].detach()
                        cont = inp
                    else:
                        inp = seq[-1][None, :]
                        cont = torch.argmax(seq[-1], dim=-1)[None, :].detach()
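                    # Note: cont always holds hard token indices and is only used to update
                    # the note-status vector below; inp (soft distribution or detached
                    # indices, chosen above) is what is fed to the generator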

                    if self.cfg.TRAIN.append_note_status:
                        bptt, batch_size = cont.shape
                        if status_vec is None:
                            status_vec = inp.new_zeros((bptt, batch_size, self.vec_len), dtype=torch.bool)
                        else:
                            status_vec = status_vec[-1:, :, :]
                        self.vocab.update_status_vec(cont, status_vec)

                    ret = self.generator.forward_generate_gumbel(inp, self.temperature, sample_mems, status_vec=status_vec)

                    logits, sample_mems = ret

                    seq.append(logits[0])
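                    # logits[0] is the Gumbel-sampled output over the vocabulary for this
                    # step (see forward_generate_gumbel and self.temperature); appending it
                    # keeps the sequence differentiable, truncated wherever inp is detached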

                # Prepare fake data

                # Ignore first token
                if len(seq) == sample_len + 1:
                    seq = seq[1:]

                fake_chunk = torch.cat(
                    [i[None, :, :] for i in seq], 0
                )  # seq_len, bsz, vocab

                if 'dis' in train_loss:
                    fake_chunk = fake_chunk.detach()

                data_chunk = data[chunk_start:chunk_end]

                if "classifier" in train_loss:
                    if self.P0 is None:
                        with torch.no_grad():
                            D0 = torch.sigmoid(self.dis_D_forward(fake_chunk))
                            self.P0 = (1. - D0) / torch.clamp(D0, min=1e-7)

                    real_label = self.P0.new_full((self.P0.shape[0],), 1.)
                    fake_label = self.P0.new_full((self.P0.shape[0],), 0.)

                    criterion = nn.BCELoss()
                    errDD_real = criterion(torch.sigmoid(self.dis_D_forward(data_chunk)), real_label)
                    errDD_fake = criterion(torch.sigmoid(self.dis_D_forward(fake_chunk.detach())), fake_label)
                    errDD_loss = errDD_real + errDD_fake

                    (errDD_loss.float().mean()
                     / (self.cfg.DISCRIMINATOR.batch_chunk
                        * self.cfg.DISCRIMINATOR.sample_chunks_mem)).backward()

                    # Reset params for next chunk
                    sample_mems = sample_mems.detach()
                    seq = [seq[-1]]

                    continue

                # Estimate the PPO importance ratio (vs. the frozen reference classifier) for the generator update
                if 'gen' in train_loss and (
                        'ppo' in self.cfg.DISCRIMINATOR.BERT.loss_type or 'ppo' in self.cfg.DISCRIMINATOR.CNN.loss_type):
                    if self.P0 is None or update_D0:
                        with torch.no_grad():
                            D0 = torch.sigmoid(self.dis_D_forward(fake_chunk))
                            self.P0 = (1. - D0) / torch.clamp(D0, min=1e-7)

                    D1 = torch.sigmoid(self.dis_D_forward(fake_chunk))
                    P1 = (1. - D1)
                    ratio = (P1 / torch.clamp(D1 * self.P0, min=1e-7))

                    ratio_clipped = torch.clamp(ratio, 1.0 - self.cfg.PPO.clip_param, 1.0 +
                                                self.cfg.PPO.clip_param)
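                    # P1 / (D1 * P0) is the importance ratio of the current classifier relative
                    # to the cached reference P0; the clipped version feeds the PPO-style
                    # surrogate target in the discriminator branches below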


                if self.cfg.DISCRIMINATOR.type == "bert":
                    # bert_vocab_size = 311

                    data_chunk = torch.transpose(data_chunk, 0, 1)
                    fake_chunk = torch.transpose(fake_chunk, 0, 1)

                    # Pad zeros corresponding to MASK token
                    fake_chunk = torch.cat(
                        [fake_chunk, fake_chunk.new_zeros((*fake_chunk.shape[:-1], 1))], -1
                    )


                    embedding_matrix = self.discriminator.bert.embeddings.word_embeddings.weight

                    emb_real = embedding_matrix[data_chunk]
                    emb_fake = torch.einsum(
                        "ve,bcv -> bce",
                        embedding_matrix,
                        fake_chunk,
                    )
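                    # emb_fake is a weighted combination of embedding rows, using the (soft)
                    # one-hot fake tokens as weights, so gradients can flow back through the
                    # sampled sequence into the generator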

                    bert_emb_real = self.discriminator(inputs_embeds=emb_real)
                    bert_emb_fake = self.discriminator(inputs_embeds=emb_fake)

                    # 1 is real and 0 is fake in bce_loss, so we extract the real index from the output vector
                    d_out_real, d_out_fake = bert_emb_real[0][:, 0], bert_emb_fake[0][:, 0]

                    if 'ppo' in self.cfg.DISCRIMINATOR.BERT.loss_type and 'gen' in train_loss:
                        surr1 = ratio * d_out_fake
                        surr2 = ratio_clipped * d_out_fake
                        target = torch.where(d_out_fake > 0, torch.min(surr1, surr2), torch.max(surr1, surr2))
                        temp_gen_loss, temp_dis_loss = get_losses(d_out_real, target,
                                                                  self.cfg.DISCRIMINATOR.BERT.loss_type)
                    else:
                        temp_gen_loss, temp_dis_loss = get_losses(d_out_real, d_out_fake,
                                                                  self.cfg.DISCRIMINATOR.BERT.loss_type)

                    # Regularize discriminator with gradient penalty
                    if "dis" in train_loss and 'gp' in self.cfg.DISCRIMINATOR.BERT.loss_type:
                        data_chunk = (
                            data_chunk.new_zeros((*data_chunk.shape, self.ntokens + 1), dtype=torch.float32)
                                .scatter_(-1, data_chunk[..., None], 1)
                        )
                        temp_gp_loss = self.calc_gradient_penalty(data_chunk, fake_chunk)

                    if self.cfg.DISCRIMINATOR.backprop_outside:
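                        # With backprop_outside, backward() is called per chunk below and only
                        # detached values are accumulated here for logging; otherwise the
                        # attached losses accumulate and are returned in return_dict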
                        gen_loss += temp_gen_loss.detach()
                        dis_loss += temp_dis_loss.detach()
                        if "dis" in train_loss and 'gp' in self.cfg.DISCRIMINATOR.BERT.loss_type:
                            gp_loss += temp_gp_loss.detach()
                    else:
                        gen_loss += temp_gen_loss
                        dis_loss += temp_dis_loss
                        if "dis" in train_loss and 'gp' in self.cfg.DISCRIMINATOR.BERT.loss_type:
                            gp_loss += temp_gp_loss

                elif self.cfg.DISCRIMINATOR.type == "cnn":
                    real_samples = (
                        data_chunk.new_zeros((*data_chunk.shape, self.ntokens), dtype=torch.float32)
                            .scatter_(-1, data_chunk[..., None], 1)
                            .transpose(0, 1)
                    )
                    gen_samples = torch.transpose(fake_chunk, 0, 1)
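                    # The CNN discriminator takes batch-first one-hot inputs: real tokens are
                    # scattered into one-hot form above, while fake samples are already (soft)
                    # one-hot distributions from the generator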

                    d_out_real = self.discriminator(real_samples)
                    d_out_fake = self.discriminator(gen_samples)

                    if 'ppo' in self.cfg.DISCRIMINATOR.CNN.loss_type and 'gen' in train_loss:
                        surr1 = ratio * d_out_fake
                        surr2 = ratio_clipped * d_out_fake
                        target = torch.where(d_out_fake > 0, torch.min(surr1, surr2), torch.max(surr1, surr2))
                        temp_gen_loss, temp_dis_loss = get_losses(d_out_real, target,
                                                                  self.cfg.DISCRIMINATOR.CNN.loss_type)
                    else:
                        temp_gen_loss, temp_dis_loss = get_losses(
                            d_out_real, d_out_fake, self.cfg.DISCRIMINATOR.CNN.loss_type
                        )
                    # Regularize discriminator with gradient penalty
                    if "dis" in train_loss and 'gp' in self.cfg.DISCRIMINATOR.CNN.loss_type:
                        temp_gp_loss = self.calc_gradient_penalty(real_samples, gen_samples)

                    if self.cfg.DISCRIMINATOR.backprop_outside:
                        gen_loss += temp_gen_loss.detach()
                        dis_loss += temp_dis_loss.detach()
                        if "dis" in train_loss and 'gp' in self.cfg.DISCRIMINATOR.CNN.loss_type:
                            gp_loss += temp_gp_loss.detach()
                    else:
                        gen_loss += temp_gen_loss
                        dis_loss += temp_dis_loss
                        if "dis" in train_loss and 'gp' in self.cfg.DISCRIMINATOR.CNN.loss_type:
                            gp_loss += temp_gp_loss

                else:
                    raise NotImplementedError


                if self.cfg.DISCRIMINATOR.backprop_outside:
                    scale = self.cfg.DISCRIMINATOR.batch_chunk * self.cfg.DISCRIMINATOR.sample_chunks_mem
                    if "gen" in train_loss:
                        (temp_gen_loss.float().mean() * self.cfg.DISCRIMINATOR.gen_loss_factor / scale).backward()
                    if "dis" in train_loss:
                        (temp_dis_loss.float().mean() * self.cfg.DISCRIMINATOR.dis_loss_factor / scale).backward()

                        if self.cfg.DISCRIMINATOR.type == "bert":
                            if 'gp' in self.cfg.DISCRIMINATOR.BERT.loss_type:
                                (temp_gp_loss.float().mean() * self.cfg.DISCRIMINATOR.dis_loss_factor / scale).backward()
                        elif self.cfg.DISCRIMINATOR.type == "cnn":
                            if 'gp' in self.cfg.DISCRIMINATOR.CNN.loss_type:
                                (temp_gp_loss.float().mean() * self.cfg.DISCRIMINATOR.dis_loss_factor / scale).backward()  # TODO CNN WGAN-GP
                        else:
                            raise NotImplementedError

                # Reset params for next chunk
                sample_mems = sample_mems.detach()
                seq = [seq[-1]]

            # Reset model parameters
            self.generator.detach_mems_grad = True
            self.generator.reset_length(cache_tgt_len, cache_mem_len)

            # Set up values to return
            if "dis" in train_loss:
                dis_loss = self.cfg.DISCRIMINATOR.dis_loss_factor * dis_loss / self.cfg.DISCRIMINATOR.sample_chunks_mem
                return_dict["dis_loss"] = dis_loss
                if self.cfg.DISCRIMINATOR.type == "bert":
                    if 'gp' in self.cfg.DISCRIMINATOR.BERT.loss_type:
                        return_dict["gp_loss"] = (
                            self.cfg.DISCRIMINATOR.dis_loss_factor * gp_loss / self.cfg.DISCRIMINATOR.sample_chunks_mem
                        )
                elif self.cfg.DISCRIMINATOR.type == "cnn":
                    if 'gp' in self.cfg.DISCRIMINATOR.CNN.loss_type:
                        return_dict["gp_loss"] = (
                            self.cfg.DISCRIMINATOR.dis_loss_factor * gp_loss / self.cfg.DISCRIMINATOR.sample_chunks_mem
                        )
                else:
                    raise NotImplementedError

            elif "gen" in train_loss:
                gen_loss = self.cfg.DISCRIMINATOR.gen_loss_factor * gen_loss / self.cfg.DISCRIMINATOR.sample_chunks_mem
                return_dict["gen_loss"] = gen_loss

        return return_dict
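
For reference, the PPO-style clipped surrogate used in both discriminator branches above can be isolated into a small standalone sketch. The function name and the default clip value below are illustrative rather than part of the repository; the body mirrors the surr1/surr2/target computation in the code, where ratio is the importance ratio derived from the auxiliary classifier and d_out_fake is the discriminator output on generated samples:

import torch

def ppo_clipped_target(d_out_fake, ratio, clip_param=0.2):
    # Clip the importance ratio, form both surrogates, and keep the pessimistic
    # one: the minimum when the fake score is positive, the maximum otherwise
    ratio_clipped = torch.clamp(ratio, 1.0 - clip_param, 1.0 + clip_param)
    surr1 = ratio * d_out_fake
    surr2 = ratio_clipped * d_out_fake
    return torch.where(d_out_fake > 0, torch.min(surr1, surr2), torch.max(surr1, surr2))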