aiops/RCRank/model/modules/FuseModel/Attention.py
import math
import torch
import torch.nn as nn
import torch.nn.functional as F


class SSP(nn.Softplus):
    """Shifted Softplus: softplus(x) - softplus(0), so that SSP(0) = 0."""

    def __init__(self, beta=1, threshold=20):
        super(SSP, self).__init__(beta, threshold)

    def forward(self, input):
        # Subtract softplus(0) so the activation passes through the origin.
        sp0 = F.softplus(torch.zeros(1), self.beta, self.threshold).item()
        return F.softplus(input, self.beta, self.threshold) - sp0


class PositionwiseFeedForward(nn.Module):
    """Two-layer position-wise feed-forward block with pre-LayerNorm and a residual connection."""

    def __init__(self, d_model, d_ff, dropout=0.1):
        super(PositionwiseFeedForward, self).__init__()
        self.w_1 = nn.Linear(d_model, d_ff)
        self.w_2 = nn.Linear(d_ff, d_model)
        self.layer_norm = LayerNorm(d_model)
        self.dropout_1 = nn.Dropout(dropout)
        self.relu = nn.ReLU()
        self.dropout_2 = nn.Dropout(dropout)

    def forward(self, x):
        inter = self.dropout_1(self.relu(self.w_1(self.layer_norm(x))))
        output = self.dropout_2(self.w_2(inter))
        return output + x


class MultiHeadedAttention(nn.Module):
    """Multi-head attention that fuses SQL, plan, log, and metrics representations.

    SQL tokens act as queries; SQL and plan keys/values use softmax attention,
    while the single-token log and metrics keys/values use sigmoid gating.
    """

    def __init__(self, head_count, model_dim, dropout=0.1, use_metrics=True, use_log=True):
        super(MultiHeadedAttention, self).__init__()
        assert model_dim % head_count == 0
        self.use_metrics = use_metrics
        self.use_log = use_log
        self.dim_per_head = model_dim // head_count
        self.model_dim = model_dim
        self.head_count = head_count

        # Per-modality key/value projections plus the shared query projection.
        self.linear_keys = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_values = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_query = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_plan_keys = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_plan_values = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_log_keys = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_log_values = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_metrics_keys = nn.Linear(model_dim, head_count * self.dim_per_head)
        self.linear_metrics_values = nn.Linear(model_dim, head_count * self.dim_per_head)

        self.softmax = nn.Softmax(dim=-1)
        self.dropout_sql = nn.Dropout(dropout)
        self.dropout_plan = nn.Dropout(dropout)
        self.dropout_log = nn.Dropout(dropout)
        self.dropout_metrics = nn.Dropout(dropout)

        # Each enabled modality contributes one model_dim-sized context vector.
        model_num = 4
        if not self.use_metrics:
            model_num -= 1
        if not self.use_log:
            model_num -= 1
        self.final_linear = nn.Linear(model_dim * model_num, model_dim)

        self.edge_project = nn.Sequential(nn.Linear(model_dim, model_dim),
                                          SSP(),
                                          nn.Linear(model_dim, model_dim // 2))
        self.edge_update = nn.Sequential(nn.Linear(model_dim * 2, model_dim),
                                         SSP(),
                                         nn.Linear(model_dim, model_dim))

    def forward(self, sql, plan, log, metrics, sql_mask, plan_mask, mask=None,
                additional_mask=None, layer_cache=None, type=None,
                edge_feature=None, pair_indices=None):
        # SQL tokens are the queries; each modality provides its own keys/values.
        query = sql
        sql_key = sql
        sql_value = sql
        plan_key = plan
        plan_value = plan
        batch_size = query.size(0)
        dim_per_head = self.dim_per_head
        head_count = self.head_count

        def shape(x):
            # (batch, len, model_dim) -> (batch, heads, len, dim_per_head)
            return x.view(batch_size, -1, head_count, dim_per_head) \
                .transpose(1, 2)

        def unshape(x):
            # (batch, heads, len, dim_per_head) -> (batch, len, model_dim)
            return x.transpose(1, 2).contiguous() \
                .view(batch_size, -1, head_count * dim_per_head)

        sql_key_projected = self.linear_keys(sql_key)
        sql_value_projected = self.linear_values(sql_value)
        query_projected = self.linear_query(query)
        sql_key_shaped = shape(sql_key_projected)
        sql_value_shaped = shape(sql_value_projected)
        plan_key_projected = self.linear_plan_keys(plan_key)
        plan_value_projected = self.linear_plan_values(plan_value)
        plan_key_shaped = shape(plan_key_projected)
        plan_value_shaped = shape(plan_value_projected)
        query_shaped = shape(query_projected)
        query_len = query_shaped.size(2)
        sql_key_len = sql_key_shaped.size(2)
        plan_key_len = plan_key_shaped.size(2)

        # SQL self-attention: scaled dot-product with softmax over SQL tokens.
        sql_query_shaped = query_shaped / math.sqrt(dim_per_head)
        scores = torch.matmul(sql_query_shaped, sql_key_shaped.transpose(2, 3))
        # Attention scores of the first head, returned for inspection.
        top_score = scores.view(batch_size, scores.shape[1],
                                query_len, sql_key_len)[:, 0, :, :].contiguous()
        attn = self.softmax(scores)
        drop_attn = self.dropout_sql(attn)
        context = torch.matmul(drop_attn, sql_value_shaped)
        sql_context = unshape(context)

        # Plan cross-attention: SQL queries attend over plan tokens.
        sql_query_shaped = query_shaped / math.sqrt(dim_per_head)
        scores = torch.matmul(sql_query_shaped, plan_key_shaped.transpose(2, 3))
        attn = self.softmax(scores)
        drop_attn = self.dropout_plan(attn)
        context = torch.matmul(drop_attn, plan_value_shaped)
        plan_context = unshape(context)

        # Metrics: a single token, gated with sigmoid instead of softmax.
        if self.use_metrics:
            metrics = metrics.unsqueeze(1)
            metrics_key = metrics
            metrics_value = metrics
            metrics_key_projected = self.linear_metrics_keys(metrics_key)
            metrics_value_projected = self.linear_metrics_values(metrics_value)
            metrics_key_shaped = shape(metrics_key_projected)
            metrics_value_shaped = shape(metrics_value_projected)
            metrics_key_len = metrics_key_shaped.size(2)
            sql_query_shaped = query_shaped / math.sqrt(dim_per_head)
            scores = torch.matmul(sql_query_shaped, metrics_key_shaped.transpose(2, 3))
            attn = torch.sigmoid(scores)
            drop_attn = self.dropout_metrics(attn)
            context = torch.matmul(drop_attn, metrics_value_shaped)
            metrics_context = unshape(context)

        # Log: a single token, gated with sigmoid instead of softmax.
        if self.use_log:
            log = log.unsqueeze(1)
            log_key = log
            log_value = log
            log_key_projected = self.linear_log_keys(log_key)
            log_value_projected = self.linear_log_values(log_value)
            log_key_shaped = shape(log_key_projected)
            log_value_shaped = shape(log_value_projected)
            sql_query_shaped = query_shaped / math.sqrt(dim_per_head)
            scores = torch.matmul(sql_query_shaped, log_key_shaped.transpose(2, 3))
            attn = torch.sigmoid(scores)
            drop_attn = self.dropout_log(attn)
            context = torch.matmul(drop_attn, log_value_shaped)
            log_context = unshape(context)

        # Concatenate the per-modality contexts and project back to model_dim.
        context = torch.cat([sql_context, plan_context], dim=-1)
        if self.use_metrics:
            context = torch.cat([context, metrics_context], dim=-1)
        if self.use_log:
            context = torch.cat([context, log_context], dim=-1)
        output = self.final_linear(context)
        return output, top_score


class LayerNorm(nn.Module):
    """Layer normalization with learnable scale (a_2) and shift (b_2)."""

    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        std = x.std(-1, keepdim=True)
        return self.a_2 * (x - mean) / (std + self.eps) + self.b_2
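

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). The tensor shapes
    # are assumptions inferred from forward(): sql and plan are token sequences
    # of width model_dim, while log and metrics are single vectors that
    # forward() expands to one "token" each. All sizes below are hypothetical.
    batch, sql_len, plan_len, model_dim, heads = 2, 8, 6, 64, 4
    fuse_attn = MultiHeadedAttention(heads, model_dim)
    sql = torch.randn(batch, sql_len, model_dim)
    plan = torch.randn(batch, plan_len, model_dim)
    log = torch.randn(batch, model_dim)
    metrics = torch.randn(batch, model_dim)
    output, top_score = fuse_attn(sql, plan, log, metrics, sql_mask=None, plan_mask=None)
    print(output.shape)     # expected: (batch, sql_len, model_dim)
    print(top_score.shape)  # expected: (batch, sql_len, sql_len)

    # The feed-forward block keeps the input shape thanks to its residual path.
    ffn = PositionwiseFeedForward(d_model=model_dim, d_ff=4 * model_dim)
    print(ffn(sql).shape)   # expected: (batch, sql_len, model_dim)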