# pyro/distributions/diag_normal_mixture.py

import math

import torch

from pyro.distributions.util import sum_leftmost

def backward(ctx, grad_output):
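    # In the full file this is the backward() of the torch.autograd.Function
    # behind MixtureOfDiagNormals.rsample(); it implements the pathwise
    # gradient for a mixture of diagonal normals (cf. Jankowiak & Obermeyer,
    # "Pathwise Derivatives Beyond the Reparameterization Trick", 2018).
    # Shape-suffix legend for the comments below:
    #   l = sample dims, b = batch dims, i = event (coordinate) dim,
    #   j / c / d = mixture-component dims (K components each).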
    z, scales, locs, logits, pis = ctx.saved_tensors
    dim = scales.size(-1)
    K = logits.size(-1)
    g = grad_output  # l b i
    g = g.unsqueeze(-2)  # l b 1 i
    batch_dims = locs.dim() - 2
    locs_tilde = locs / scales  # b j i
    sigma_0 = torch.min(scales, -2, keepdim=True)[0]  # b 1 i
    z_shift = (z.unsqueeze(-2) - locs) / sigma_0  # l b j i
    z_tilde = z.unsqueeze(-2) / scales - locs_tilde  # l b j i
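    # z_tilde is z standardized per component, (z - loc_j) / scale_j, and
    # z_shift rescales the same residual by the per-dimension minimum scale
    # sigma_0; both feed the cumulative radius r_sqr_ji computed below.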
    mu_cd = locs.unsqueeze(-2) - locs.unsqueeze(-3)  # b c d i
    mu_cd_norm = torch.pow(mu_cd, 2.0).sum(-1).sqrt()  # b c d
    mu_cd /= mu_cd_norm.unsqueeze(-1)  # b c d i
    diagonals = torch.arange(K, dtype=torch.long, device=z.device)
    mu_cd[..., diagonals, diagonals, :] = 0.0
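    # mu_cd[..., c, d, :] is the unit vector pointing from the mean of
    # component d to the mean of component c.  Normalizing produces 0/0 = NaN
    # on the c == d diagonal, which the explicit zeroing above repairs.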
    mu_ll_cd = (locs.unsqueeze(-2) * mu_cd).sum(-1)  # b c d
    z_ll_cd = (z.unsqueeze(-2).unsqueeze(-2) * mu_cd).sum(-1)  # l b c d
    z_perp_cd = z.unsqueeze(-2).unsqueeze(-2) - z_ll_cd.unsqueeze(-1) * mu_cd  # l b c d i
    z_perp_cd_sqr = torch.pow(z_perp_cd, 2.0).sum(-1)  # l b c d
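    # Decompose z relative to each component pair (c, d): z_ll_cd is the
    # projection of z onto mu_cd, and z_perp_cd is the orthogonal remainder
    # with squared norm z_perp_cd_sqr.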
    shift_indices = torch.arange(dim, dtype=torch.long, device=z.device) - 1
    shift_indices[0] = 0

    z_shift_cumsum = torch.pow(z_shift, 2.0)
    z_shift_cumsum = z_shift_cumsum.sum(-1, keepdim=True) - torch.cumsum(z_shift_cumsum, dim=-1)  # l b j i
    z_tilde_cumsum = torch.cumsum(torch.pow(z_tilde, 2.0), dim=-1)  # l b j i
    z_tilde_cumsum = torch.index_select(z_tilde_cumsum, -1, shift_indices)
    z_tilde_cumsum[..., 0] = 0.0
    r_sqr_ji = z_shift_cumsum + z_tilde_cumsum  # l b j i
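    # r_sqr_ji[..., j, i] combines the z_shift contributions from dimensions
    # above i with the z_tilde contributions from dimensions below i; the
    # shifted cumsum drops dimension i itself from the z_tilde part.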
    log_scales = torch.log(scales)  # b j i
    epsilons_sqr = torch.pow(z_tilde, 2.0)  # l b j i
    log_qs = -0.5 * epsilons_sqr - 0.5 * math.log(2.0 * math.pi) - log_scales  # l b j i
    log_q_j = log_qs.sum(-1, keepdim=True)  # l b j 1
    q_j = torch.exp(log_q_j)  # l b j 1
    q_tot = (pis * q_j.squeeze(-1)).sum(-1)  # l b
    q_tot = q_tot.unsqueeze(-1)  # l b 1
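    # q_j is the density of component j at z (product of per-dimension normal
    # densities) and q_tot is the full mixture density, so pis * q_j / q_tot
    # acts as the per-component responsibility weights used below.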
    root_two = math.sqrt(2.0)
    shift_log_scales = log_scales[..., shift_indices]
    shift_log_scales[..., 0] = 0.0
    sigma_products = torch.cumsum(shift_log_scales, dim=-1).exp()  # b j i

    reverse_indices = torch.arange(dim - 1, -1, -1, dtype=torch.long, device=z.device)
    reverse_log_sigma_0 = sigma_0.log()[..., reverse_indices]  # b 1 i
    sigma_0_products = torch.cumsum(reverse_log_sigma_0, dim=-1).exp()[..., reverse_indices - 1]  # b 1 i
    sigma_0_products[..., -1] = 1.0
    sigma_products *= sigma_0_products
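    # sigma_products[..., j, i] is the product of component j's scales over
    # dimensions below i times the product of sigma_0 over dimensions above i,
    # mirroring the split in r_sqr_ji.  The wrap-around entry selected by
    # reverse_indices - 1 at index -1 is overwritten with 1.0 (empty product).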
    logits_grad = torch.erf(z_tilde / root_two) - torch.erf(z_shift / root_two)  # l b j i
    logits_grad *= torch.exp(-0.5 * r_sqr_ji)  # l b j i
    logits_grad = (logits_grad * g / sigma_products).sum(-1)  # l b j
    logits_grad = sum_leftmost(logits_grad / q_tot, -1 - batch_dims)  # b j
    logits_grad *= 0.5 * math.pow(2.0 * math.pi, -0.5 * (dim - 1))
    logits_grad = -pis * logits_grad
    logits_grad = logits_grad - logits_grad.sum(-1, keepdim=True) * pis
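    # Each erf difference integrates a standardized 1-d Gaussian between the
    # two boundary coordinates; the last two lines apply the softmax Jacobian,
    # so the logits gradient sums to zero across components.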
    mu_ll_dc = torch.transpose(mu_ll_cd, -1, -2)
    v_cd = torch.erf((z_ll_cd - mu_ll_cd) / root_two) - torch.erf((z_ll_cd + mu_ll_dc) / root_two)
    v_cd *= torch.exp(-0.5 * z_perp_cd_sqr)  # l b c d
    mu_cd_g = (g.unsqueeze(-2) * mu_cd).sum(-1)  # l b c d
    v_cd *= -mu_cd_g * pis.unsqueeze(-2) * 0.5 * math.pow(2.0 * math.pi, -0.5 * (dim - 1))  # l b c d
    v_cd = pis * sum_leftmost(v_cd.sum(-1) / q_tot, -1 - batch_dims)
    logits_grad += v_cd
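    # v_cd adds the cross-component boundary corrections: for every ordered
    # pair (c, d) a Gaussian slab integral along mu_cd couples the upstream
    # gradient g to the pair's mixture weights.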
    prefactor = pis.unsqueeze(-1) * q_j * g / q_tot.unsqueeze(-1)  # l b j i
    locs_grad = sum_leftmost(prefactor, -2 - batch_dims)  # b j i
    scales_grad = sum_leftmost(prefactor * z_tilde, -2 - batch_dims)  # b j i
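    # locs and scales pick up the usual reparameterization terms, weighted by
    # the responsibilities: dz/dloc = 1 and dz/dscale = z_tilde (the
    # standardized noise) for each component.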
    # Gradients for (locs, scales, logits); the remaining forward arguments
    # are non-differentiable, so they receive None.
    return locs_grad, scales_grad, logits_grad, None, None, None
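
# A minimal smoke-test sketch, not part of the original file: it assumes this
# backward() is wired into pyro.distributions.MixtureOfDiagNormals, whose
# rsample() routes gradients through the autograd Function above.
if __name__ == "__main__":
    from pyro.distributions import MixtureOfDiagNormals

    K, D = 3, 2
    locs = torch.randn(K, D, requires_grad=True)
    coord_scale = (torch.rand(K, D) + 0.5).requires_grad_()
    component_logits = torch.zeros(K, requires_grad=True)

    dist = MixtureOfDiagNormals(locs, coord_scale, component_logits)
    z = dist.rsample(torch.Size([1000]))  # pathwise samples, shape (1000, D)
    loss = z.pow(2.0).sum(-1).mean()
    loss.backward()  # exercises the custom backward() above
    print(locs.grad.shape, coord_scale.grad.shape, component_logits.grad.shape)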