def forward()

in dib/transformers/ib/dib.py [0:0]


    def forward(self, out, targets):
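        """Return the DIB loss for one batch.

        `out` is expected to be the tuple `(y_pred, z_sample, p_zCx)` produced by the
        model, and `targets` the corresponding labels. During training the method
        returns `zy_loss + zx_loss`; at evaluation it returns only `zy_loss`.
        """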

        if self.training:
            self._counter_warm_Q_zx += 1

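        # unpack the model output: predictions for y, sampled latent z, and the
        # encoder distribution p(z|x)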
        y_pred, z_sample, p_zCx = out

        if p_zCx.out is not None:
            p_zCx_base = p_zCx.out.base_dist
            # z_norm: mean squared value of the z samples; z_mean_norm: mean absolute
            # value of the posterior mean; z_std: mean posterior standard deviation
            self._store(
                z_norm=z_sample.pow(2).mean(),
                z_mean_norm=p_zCx_base.loc.abs().mean(),
                z_std=p_zCx_base.scale.mean(),
            )

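        # For the H_Q'[X|Z,Y] variant, repeat the target across the z samples and
        # append it to z so that compute_zx_loss is conditioned on Y as well.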
        if self.conditional == "H_Q'[X|Z,Y]":
            target = (
                extract_target(targets, self.map_target_position)
                .unsqueeze(0)
                .unsqueeze(-1)
                .float()
            )
            target = torch.repeat_interleave(target, z_sample.size(0), dim=0)
            z_sample = torch.cat([z_sample, target], dim=-1)

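        # Z -> X term, computed by the auxiliary predictor (stored as aux_loss below)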
        try:
            zx_loss = self.compute_zx_loss(z_sample, targets)
        except NotEnoughHeads as e:
            # If not training, do not raise: the indexing might be off, in which case the
            # predictor cannot compute zx_loss. We still attempt this loss outside of
            # training because evaluation is also run on the training data with
            # self.training=False.
            if self.training:
                raise e
            zx_loss = 0

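        # optional KL regularizer: KL(p(z|x) || N(0, I)), averaged over the batch and
        # converted from nats to base-BASE_LOG units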
        if self.weight_kl is not None:
            p_zCx = p_zCx.out
            mean_0 = torch.zeros_like(p_zCx.base_dist.loc)
            std_1 = torch.ones_like(p_zCx.base_dist.scale)
            p_z = MultivariateNormalDiag(mean_0, std_1)
            kl = kl_divergence(p_zCx, p_z).mean(0) / math.log(BASE_LOG)
            zx_loss = zx_loss + self.weight_kl * kl

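        # Z -> Y term: prediction loss of y_pred against the targets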
        zy_loss = self.compute_zy_loss(y_pred, targets)

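        # while zy_loss is above the sufficiency threshold, keep the value of zx_loss
        # but stop its gradient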
        if zy_loss > self.threshold_suff:  # DEV
            zx_loss = zx_loss * 0 + detach(zx_loss)

        if self._counter_warm_Q_zx <= self.warm_Q_zx:
            # during the Q_zx warm-up, still return the zy loss value but without gradient
            zy_loss = zy_loss * 0 + detach(zy_loss)

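        # DEV option: optimize only the zx term by zeroing zy_loss (value and gradient)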
        if self.is_zx_only:  # DEV
            zy_loss = 0 * zy_loss

        self._store(aux_loss=zx_loss)

        if not self.training:
            # when evaluating, the returned loss should be the log likelihood only
            # (zy_loss), so it can be used for checkpointing
            return zy_loss

        return zy_loss + zx_loss
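
For context, here is a minimal, self-contained sketch of the KL regularizer used in the weight_kl branch above. It assumes that MultivariateNormalDiag is a diagonal Gaussian built from torch.distributions (so that .base_dist.loc and .base_dist.scale exist) and that BASE_LOG is 2; both definitions below are local stand-ins, not taken from the repository.

    import math

    import torch
    from torch.distributions import Independent, Normal, kl_divergence

    BASE_LOG = 2  # assumption: losses are reported in bits


    def MultivariateNormalDiag(loc, scale):
        # assumption: diagonal-covariance Gaussian wrapped so `.base_dist` is the Normal
        return Independent(Normal(loc, scale), 1)


    # encoder output p(z|x) for a batch of 4 examples with an 8-dimensional z
    loc = torch.randn(4, 8)
    scale = 0.5 * torch.ones(4, 8)
    p_zCx = MultivariateNormalDiag(loc, scale)

    # standard Gaussian prior p(z) = N(0, I)
    p_z = MultivariateNormalDiag(torch.zeros_like(loc), torch.ones_like(scale))

    # batch-averaged KL, converted from nats to base-BASE_LOG units
    kl = kl_divergence(p_zCx, p_z).mean(0) / math.log(BASE_LOG)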