in probabilisticattention.py
def forward(self, q, zeta, alpha, mu, beta, pi=None, v_init=None, v_fixed=None,
zeta_prior_precision=None, mu_prior_precision=None,
q_pos_emb=None, zeta_pos_emb=None, v_pos_emb=None, nonzero_wts_mask=None):
"""
Runs an update of the probabilistic version of attention based on a Mixture of Gaussians model.
This layer is equivalent to standard dot-product attention when:
self.uniform_query_precision = True
self.uniform_value_precision = True
self.magnitude_priors = True
alpha = 1/sqrt(C) (can be a scalar to save memory)
beta = 0 (can be a scalar to save memory)
v_init = None
v_fixed = None
:param q: A tensor of queries with dims N, G, C, H
:param zeta: A tensor of keys (query/key Gaussian means) with dims N, G, C, H
:param alpha: A scalar (see special case above) or tensor of query/key Gaussian precisions with dims N, G, C, H
:param mu: A tensor of value Gaussian means with dims N, G, Cv, H
:param beta: A scalar (see special case above) or tensor of value Gaussian precisions with dims N, G, C, H
:param pi: A tensor of mixture component priors with dims N, G, H, H
:param v_init: A tensor of initial estimates for the values with dims N, G, Cv, H (optional)
:param v_fixed: A tensor of fixed values with dims N, G, (Cv+1), H (optional). The extra (last) channel is an indicator for the fixed-value locations; any channels between the first Cv and the indicator carry positional-offset values that are used when v_pos_emb is given
:param zeta_prior_precision: A tensor of precisions for the Gaussian prior over zeta with dims N, G, C, H (optional)
:param mu_prior_precision: A tensor of precisions for the Gaussian prior over mu with dims N, G, Cv, H (optional)
:param q_pos_emb: A tensor of query positional embeddings with dims C, H, H
:param zeta_pos_emb: A tensor of key positional embeddings with dims C, H, H
:param v_pos_emb: A tensor of value positional embeddings with dims Cv, H, H
:param nonzero_wts_mask: A boolean mask with dims H, H; weight-matrix entries are set to zero where the mask is False
:return: Updated values with dims N, G, Cv, H if there is no value positional embedding (v_pos_emb=None), else N, G, 2*Cv, H
"""
N, G, C_qk, H = q.shape
C_v = mu.shape[-2]
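# update_weights computes the posterior responsibility of mixture component j for location i,
# up to normalisation over j and the Gaussian normalising constants:
#   w_ij ∝ pi_ij * N(q_i; zeta_j, alpha^-1 I) * N(v_init_i; mu_j, beta^-1 I)
# The value factor is only included when v_init is given, and pi_ij is replaced by the implicit
# magnitude priors when pi is None and self.magnitude_priors is set. Everything is accumulated
# in log space and normalised with a log-sum-exp for numerical stability.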
def update_weights():
q_2 = torch.sum(q**2, dim=-2)
zeta_2 = torch.sum(zeta**2, dim=-2)
q_zeta = torch.einsum('bgci, bgcj->bgij', q, zeta)
# Squared distances ||q_i - zeta_j||^2, accumulated into the joint log-probability
log_p_q_v = q_2.unsqueeze(-1) + zeta_2.unsqueeze(-2) - 2 * q_zeta
if q_pos_emb is not None:
q_pos_emb_2 = torch.sum(q_pos_emb**2, dim=0)
q_q_pos_emb = torch.einsum('bgci, cij->bgij', q, q_pos_emb)
# Add the ||q_i - q_pos_emb_ij||^2 contribution
log_p_q_v += q_2.unsqueeze(-1) + q_pos_emb_2.unsqueeze(0).unsqueeze(0) - 2 * q_q_pos_emb
if zeta_pos_emb is not None:
zeta_pos_emb_2 = torch.sum(zeta_pos_emb ** 2, dim=0).transpose(0, 1)
zeta_zeta_pos_emb = torch.einsum('bgci, cij->bgij', zeta, zeta_pos_emb).transpose(2, 3)
# Add the ||zeta_j - zeta_pos_emb_ij||^2 contribution
log_p_q_v += zeta_2.unsqueeze(-2) + zeta_pos_emb_2.unsqueeze(0).unsqueeze(0) - 2 * zeta_zeta_pos_emb
# Scale the accumulated squared distances by the query/key precision alpha
if self.uniform_query_precision:
log_p_q_v = -0.5 * alpha * log_p_q_v
else:
log_p_q_v = -0.5 * alpha.unsqueeze(-2) * log_p_q_v
mu_2 = torch.sum(mu**2, dim=-2)
if v_init is not None:
# Value term: -0.5 * beta * ||v_init_i - mu_j||^2
v_init_2 = torch.sum(v_init**2, dim=-2)
v_init_mu = torch.einsum('bgci, bgcj->bgij', v_init, mu)
if self.uniform_value_precision:
log_p_q_v += -0.5 * beta * (v_init_2.unsqueeze(-1) + mu_2.unsqueeze(-2) - 2 * v_init_mu)
else:
log_p_q_v += -0.5 * beta.unsqueeze(-2) * (v_init_2.unsqueeze(-1) + mu_2.unsqueeze(-2) - 2 * v_init_mu)
# Mixture priors: either an explicit pi, or implicit magnitude priors that cancel the
# squared-magnitude terms of the Gaussian means so the weights reduce to a softmax over
# (precision-scaled) dot products
if pi is not None:
log_p_q_v += torch.log(pi)
elif self.magnitude_priors:
if self.uniform_query_precision:
alpha_tensor = alpha
else:
alpha_tensor = alpha.unsqueeze(-2)
log_p_q_v += 0.5 * alpha_tensor * zeta_2.unsqueeze(-2)
if q_pos_emb is not None:
log_p_q_v += 0.5 * alpha_tensor * q_pos_emb_2.unsqueeze(0).unsqueeze(0)
if zeta_pos_emb is not None:
log_p_q_v += 0.5 * alpha_tensor * zeta_2.unsqueeze(-2)
log_p_q_v += 0.5 * alpha_tensor * zeta_pos_emb_2.unsqueeze(0).unsqueeze(0)
if self.uniform_value_precision:
beta_tensor = beta
else:
beta_tensor = beta.unsqueeze(-2)
if v_pos_emb is not None:
mu_p_v_pos_emb = mu.unsqueeze(-2) + v_pos_emb.unsqueeze(0).unsqueeze(0)
mu_p_v_pos_emb_2 = torch.sum(mu_p_v_pos_emb**2, dim=-3)
log_p_q_v += 0.5 * beta_tensor * mu_p_v_pos_emb_2
else:
log_p_q_v += 0.5 * beta_tensor * mu_2.unsqueeze(-2)
# log_sum_exp trick to avoid numerical underflow
m, _ = torch.max(log_p_q_v, dim=-1, keepdim=True)
weights = torch.exp(log_p_q_v - m)
if nonzero_wts_mask is not None:
weights = weights * nonzero_wts_mask.unsqueeze(0).unsqueeze(0).float()
sum_weights = torch.sum(weights, dim=-1, keepdim=True) + eps
weights = weights.div(sum_weights)
return weights
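# Initial E-step: responsibilities computed from the inputs as given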
weights = update_weights()
if self.key_adaptation:
# Online key adaptation
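# M-step for the key means: each zeta_j moves towards the responsibility-weighted mean of the
# queries it explains; with zeta_prior_precision this becomes a MAP update that blends the
# current zeta (acting as the prior mean) with the weighted query mean, using alpha as the
# observation precision.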
for ka_iter in range(self.key_adaptation_iters):
zeta_update = torch.einsum('bgij,bgci->bgcj', weights, q)
sum_weights = torch.sum(weights, dim=-2, keepdim=True)
if zeta_prior_precision is not None:
zeta = zeta_prior_precision * zeta + alpha * zeta_update
zeta = zeta.div(zeta_prior_precision + alpha * sum_weights)
else:
zeta = zeta_update
zeta = zeta.div(sum_weights)
weights = update_weights()
# Accumulator for value positional-embedding offsets (zeros_like already matches mu's device and dtype)
wve = torch.zeros_like(mu)
if v_fixed is not None:
# Online value belief propagation
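# Where fixed values are supplied (indicator channel > 0), mu is re-estimated from the
# responsibility-weighted fixed values (a MAP update when mu_prior_precision is given) and the
# responsibilities are then recomputed, propagating information from the fixed locations to
# the rest of the sequence.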
for vbp_iter in range(self.value_belief_propagation_iters):
if torch.sum(v_fixed[:, :, -1, :]) > 0:
mu_update = torch.einsum('bgij,bgci->bgcj', weights, v_fixed[:, :, :C_v, :])
sum_weights = torch.einsum('bgij,bgi->bgj', weights, v_fixed[:, :, -1, :]).unsqueeze(-2) + eps
if mu_prior_precision is not None:
mu = mu_prior_precision * mu + beta * mu_update
mu = mu.div(mu_prior_precision + beta * sum_weights)
else:
mu = mu_update
mu = mu.div(sum_weights)
# Offset contributions from v_pos_emb with learnt parameters
if v_pos_emb is not None:
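# Channels C_v:-1 of v_fixed (between the Cv value channels and the final indicator channel)
# carry the positional-offset contributions; this slice is empty if v_fixed has only Cv+1 channels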
wve += torch.einsum('bgij,bgcj->bgci', weights, v_fixed[:, :, C_v:-1, :])
# Update weights
weights = update_weights()
v_updated = torch.einsum('bgij,bgcj->bgci', weights, mu) # Should we force v_updated = v_fixed at specified locs?
if v_pos_emb is not None:
wve += torch.einsum('bgij,cij->bgci', weights, v_pos_emb)
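# Note: concatenating along the last (H) dim and then viewing as (N, G, 2*Cv, H) interleaves
# each updated value channel with its corresponding positional-offset channel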
v_updated = torch.cat([v_updated, wve], dim=-1).view(N, G, C_v * 2, H)
return v_updated
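
# For reference, the special case listed in the docstring (uniform precisions, magnitude priors,
# alpha = 1/sqrt(C), beta = 0, no v_init/v_fixed, no pi and no positional embeddings) reduces to
# standard dot-product attention: the magnitude priors cancel the -0.5*alpha*||zeta_j||^2 terms
# and the -0.5*alpha*||q_i||^2 terms are constant per query, so the weights become
# softmax(q^T zeta / sqrt(C)) applied to mu. The sketch below is illustrative only;
# reference_dot_product_attention is not part of this module.

def reference_dot_product_attention(q, zeta, mu):
    """Standard dot-product attention over tensors shaped N, G, C, H / N, G, Cv, H (illustrative)."""
    C = q.shape[-2]
    scores = torch.einsum('bgci,bgcj->bgij', q, zeta) / (C ** 0.5)  # N, G, H, H
    weights = torch.softmax(scores, dim=-1)  # normalise over the key index j
    return torch.einsum('bgij,bgcj->bgci', weights, mu)  # N, G, Cv, H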