in aiops/RCRank/model/modules/FuseModel/Attention.py [0:0]
def forward(self, sql, plan, log, metrics, sql_mask, plan_mask, mask=None, additional_mask=None, layer_cache=None, type=None, edge_feature=None, pair_indices=None):
    """Fuse SQL, plan, and optional metrics/log representations with multi-head cross-attention.

    The SQL embedding acts as the query against each modality's keys/values.
    SQL and plan attention weights use softmax; metrics and log attention use
    sigmoid gating over a single key/value slot (each is ``unsqueeze``-d to
    length 1). The per-modality contexts are concatenated along the feature
    dimension in the order sql, plan, [metrics], [log] and projected by
    ``self.final_linear``.

    NOTE(review): ``sql_mask``, ``plan_mask``, ``mask``, ``additional_mask``,
    ``layer_cache``, ``type`` (shadows the builtin; name kept for caller
    compatibility), ``edge_feature`` and ``pair_indices`` are accepted but
    unused in this implementation.

    Args:
        sql: SQL token embeddings; assumed (batch, sql_len, model_dim) — TODO confirm.
        plan: plan embeddings; assumed (batch, plan_len, model_dim) — TODO confirm.
        log: log embedding, used only when ``self.use_log`` is set.
        metrics: metrics embedding, used only when ``self.use_metrics`` is set.

    Returns:
        (output, top_score): the fused representation after ``final_linear``,
        and head 0's raw (pre-softmax) SQL attention scores of shape
        (batch, query_len, sql_key_len).
    """
    query = sql
    batch_size = query.size(0)
    dim_per_head = self.dim_per_head
    head_count = self.head_count

    def shape(x):
        # (B, L, H*D) -> (B, H, L, D)
        return x.view(batch_size, -1, head_count, dim_per_head).transpose(1, 2)

    def unshape(x):
        # (B, H, L, D) -> (B, L, H*D)
        return x.transpose(1, 2).contiguous().view(batch_size, -1, head_count * dim_per_head)

    # Project and split each modality into heads. SQL provides the query.
    sql_key_shaped = shape(self.linear_keys(sql))
    sql_value_shaped = shape(self.linear_values(sql))
    plan_key_shaped = shape(self.linear_plan_keys(plan))
    plan_value_shaped = shape(self.linear_plan_values(plan))
    query_shaped = shape(self.linear_query(query))

    query_len = query_shaped.size(2)
    sql_key_len = sql_key_shaped.size(2)

    # Scale the query once (the original recomputed this before every modality).
    scaled_query = query_shaped / math.sqrt(dim_per_head)

    # --- SQL attention (softmax) ---
    scores = torch.matmul(scaled_query, sql_key_shaped.transpose(2, 3))
    # Expose head 0's raw scores to the caller (e.g. for interpretability).
    top_score = scores.view(batch_size, scores.shape[1],
                            query_len, sql_key_len)[:, 0, :, :].contiguous()
    drop_attn = self.dropout_sql(self.softmax(scores))
    sql_context = unshape(torch.matmul(drop_attn, sql_value_shaped))

    # --- Plan attention (softmax) ---
    scores = torch.matmul(scaled_query, plan_key_shaped.transpose(2, 3))
    drop_attn = self.dropout_plan(self.softmax(scores))
    plan_context = unshape(torch.matmul(drop_attn, plan_value_shaped))

    context = torch.cat([sql_context, plan_context], dim=-1)

    # --- Metrics attention (sigmoid gate over one key/value slot) ---
    if self.use_metrics:
        metrics = metrics.unsqueeze(1)
        metrics_key_shaped = shape(self.linear_metrics_keys(metrics))
        metrics_value_shaped = shape(self.linear_metrics_values(metrics))
        scores = torch.matmul(scaled_query, metrics_key_shaped.transpose(2, 3))
        drop_attn = self.dropout_metrics(torch.sigmoid(scores))
        metrics_context = unshape(torch.matmul(drop_attn, metrics_value_shaped))
        context = torch.cat([context, metrics_context], dim=-1)

    # --- Log attention (sigmoid gate over one key/value slot) ---
    if self.use_log:
        log = log.unsqueeze(1)
        log_key_shaped = shape(self.linear_log_keys(log))
        log_value_shaped = shape(self.linear_log_values(log))
        scores = torch.matmul(scaled_query, log_key_shaped.transpose(2, 3))
        drop_attn = self.dropout_log(torch.sigmoid(scores))
        log_context = unshape(torch.matmul(drop_attn, log_value_shaped))
        context = torch.cat([context, log_context], dim=-1)

    return self.final_linear(context), top_score