dib/transformers/ib/dib.py [734:754]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
            )

        if self.conditional == "H_Q'[X|Z,Y]":
            target = (
                extract_target(targets, self.map_target_position)
                .unsqueeze(0)
                .unsqueeze(-1)
                .float()
            )
            target = torch.repeat_interleave(target, z_sample.size(0), dim=0)
            z_sample = torch.cat([z_sample, target], dim=-1)

        try:
            zx_loss = self.compute_zx_loss(z_sample, targets)
        except NotEnoughHeads as e:
            # if not training then don't raise exception (because the indexing might be off in which
            # case your predictor cannot comoute zx_loss). But you don't want to never compute this
            # loss as for evaluation we give the training data but self.training=False
            if self.training:
                raise e
            zx_loss = 0
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



dib/transformers/ib/dib.py [789:809]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
        )  # when only computing DIQ the samples will be squeezed

        if self.conditional == "H_Q'[X|Z,Y]":
            target = (
                extract_target(targets, self.map_target_position)
                .unsqueeze(0)
                .unsqueeze(-1)
                .float()
            )
            target = torch.repeat_interleave(target, z_sample.size(0), dim=0)
            z_sample = torch.cat([z_sample, target], dim=-1)

        try:
            zx_loss = self.compute_zx_loss(z_sample, targets)
        except NotEnoughHeads as e:
            # if not training then don't raise exception (because the indexing might be off in which
            # case your predictor cannot comoute zx_loss). But you don't want to never compute this
            # loss as for evaluation we give the training data but self.training=False
            if self.training:
                raise e
            zx_loss = 0
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



