in dib/transformers/ib/dib.py [0:0]
def forward(self, out, targets):
    if self.training:
        self._counter_warm_Q_zx += 1
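        # This counter of training steps is compared against `self.warm_Q_zx` below
        # to decide whether zy_loss still receives gradients.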

    y_pred, z_sample, p_zCx = out

    if p_zCx.out is not None:
        p_zCx_base = p_zCx.out.base_dist
        # z_norm: mean squared value of the z samples, z_mean_norm: mean absolute
        # value of the posterior means, z_std: mean posterior standard deviation
        self._store(
            z_norm=z_sample.pow(2).mean(),
            z_mean_norm=p_zCx_base.loc.abs().mean(),
            z_std=p_zCx_base.scale.mean(),
        )
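
    # For the conditional variant H_Q'[X|Z,Y], append the label to every z sample so
    # that compute_zx_loss receives both Z and Y as input.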
    if self.conditional == "H_Q'[X|Z,Y]":
        target = (
            extract_target(targets, self.map_target_position)
            .unsqueeze(0)
            .unsqueeze(-1)
            .float()
        )
        target = torch.repeat_interleave(target, z_sample.size(0), dim=0)
        z_sample = torch.cat([z_sample, target], dim=-1)

    try:
        zx_loss = self.compute_zx_loss(z_sample, targets)
    except NotEnoughHeads as e:
        # If not training, do not raise: the indexing might be off, in which case the
        # predictor cannot compute zx_loss. We still try to compute this loss when not
        # training, because evaluation is also run on the training data with
        # self.training == False.
        if self.training:
            raise e
        zx_loss = 0
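
    # Optionally regularize the encoder with a KL term toward a standard Gaussian
    # prior N(0, I), weighted by `self.weight_kl`.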
    if self.weight_kl is not None:
        p_zCx = p_zCx.out
        mean_0 = torch.zeros_like(p_zCx.base_dist.loc)
        std_1 = torch.ones_like(p_zCx.base_dist.scale)
        p_z = MultivariateNormalDiag(mean_0, std_1)
        kl = kl_divergence(p_zCx, p_z).mean(0) / math.log(BASE_LOG)
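        # kl_divergence returns nats; dividing by math.log(BASE_LOG) expresses the
        # term in base-BASE_LOG units (bits if BASE_LOG == 2, which is an assumption).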
        zx_loss = zx_loss + self.weight_kl * kl

    zy_loss = self.compute_zy_loss(y_pred, targets)

    if zy_loss > self.threshold_suff:  # DEV
        zx_loss = zx_loss * 0 + detach(zx_loss)
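        # `x * 0 + detach(x)` keeps the numerical value of the loss (it is still logged
        # and added to the total) while blocking its gradient; the same trick is used
        # for zy_loss during the warm-up below.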

    if self._counter_warm_Q_zx <= self.warm_Q_zx:
        # During the Q_zx warm-up, still return the loss but without gradient.
        zy_loss = zy_loss * 0 + detach(zy_loss)

    if self.is_zx_only:  # DEV
        zy_loss = 0 * zy_loss

    self._store(aux_loss=zx_loss)

    if not self.training:
        # When evaluating, the loss should be the log likelihood (used for checkpointing).
        return zy_loss

    return zy_loss + zx_loss
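

# Minimal usage sketch (illustration only, not part of the original file). It assumes
# `criterion` is the loss module that owns the forward() above (an nn.Module, so
# .train()/.eval() toggle self.training) and that `model` returns the
# (y_pred, z_sample, p_zCx) triple it expects; the names `model`, `criterion`,
# `optimizer`, and `_train_step` are hypothetical.
def _train_step(model, criterion, optimizer, inputs, targets):
    criterion.train()      # self.training == True -> returns zy_loss + zx_loss
    out = model(inputs)    # expected to be the (y_pred, z_sample, p_zCx) triple
    loss = criterion(out, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()     # in eval mode, criterion(out, targets) returns only zy_loss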