aiops/RCRank/model/modules/QueryFormer/QueryFormer.py (261 lines of code) (raw):

import torch
from torch.utils.data import Dataset
import torch.nn as nn
import torch.nn.functional as F


class Prediction(nn.Module):
    def __init__(self, in_feature=69, hid_units=256, contract=1, mid_layers=True, res_con=True):
        super(Prediction, self).__init__()
        self.mid_layers = mid_layers
        self.res_con = res_con

        self.out_mlp1 = nn.Linear(in_feature, hid_units)
        self.mid_mlp1 = nn.Linear(hid_units, hid_units // contract)
        self.mid_mlp2 = nn.Linear(hid_units // contract, hid_units)

    def forward(self, features):
        hid = F.relu(self.out_mlp1(features))
        if self.mid_layers:
            mid = F.relu(self.mid_mlp1(hid))
            mid = F.relu(self.mid_mlp2(mid))
            if self.res_con:
                hid = hid + mid
            else:
                hid = mid
        out = hid
        return out


class FeatureEmbed(nn.Module):
    def __init__(self, embed_size=32, tables=1500, types=1500, joins=1500, columns=3000,
                 ops=4, use_sample=True, use_hist=True, bin_number=50):
        super(FeatureEmbed, self).__init__()

        self.use_sample = use_sample
        self.embed_size = embed_size
        self.use_hist = use_hist
        self.bin_number = bin_number

        self.typeEmbed = nn.Embedding(types, embed_size)
        self.tableEmbed = nn.Embedding(tables, embed_size)
        self.columnEmbed = nn.Embedding(columns, embed_size)
        self.opEmbed = nn.Embedding(ops, embed_size // 8)

        self.linearFilter2 = nn.Linear(embed_size + embed_size // 8, embed_size + embed_size // 8)
        self.linearFilter = nn.Linear(embed_size + embed_size // 8, embed_size + embed_size // 8)
        self.linearType = nn.Linear(embed_size, embed_size)
        self.linearJoin = nn.Linear(embed_size, embed_size)
        self.linearSample = nn.Linear(1000, embed_size)
        self.linearHist = nn.Linear(bin_number, embed_size)
        self.joinEmbed = nn.Embedding(joins, embed_size)

        # Histogram features are hard-disabled here, overriding the constructor argument.
        use_hist = False
        self.use_hist = False
        if use_hist:
            self.project = nn.Linear(embed_size * 5 + embed_size // 8 + 1 + 4,
                                     embed_size * 5 + embed_size // 8 + 1 + 4)
        else:
            self.project = nn.Linear(embed_size * 4 + embed_size // 8 + 4,
                                     embed_size * 4 + embed_size // 8 + 4)

    def forward(self, feature):
        # Split the flat per-node feature vector into its components:
        # operator type, join, filters, filter mask, table id + sample bitmap, cost.
        typeId, joinId, filtersId, filtersMask, table_sample, cost = torch.split(
            feature, (1, 1, 40, 20, 1001, 4), dim=-1)

        typeEmb = self.getType(typeId)
        joinEmb = self.getJoin(joinId)
        filterEmbed = self.getFilter(filtersId, filtersMask)
        tableEmb = self.getTable(table_sample)
        # histEmb is never computed because use_hist is forced off in __init__.
        histEmb = None

        if self.use_hist:
            final = torch.cat((typeEmb, filterEmbed, joinEmb, tableEmb, histEmb, cost), dim=1)
        else:
            final = torch.cat((typeEmb, filterEmbed, joinEmb, tableEmb, cost), dim=1)
        final = F.leaky_relu(self.project(final))

        return final

    def getType(self, typeId):
        emb = self.typeEmbed(typeId.long())
        return emb.squeeze(1)

    def getTable(self, table_sample):
        table, sample = torch.split(table_sample, (1, 1000), dim=-1)
        emb = self.tableEmbed(table.long()).squeeze(1)
        if self.use_sample:
            emb += self.linearSample(sample)
        return emb

    def getJoin(self, joinId):
        emb = self.joinEmbed(joinId.long())
        return emb.squeeze(1)

    def getHist(self, hists, filtersMask):
        histExpand = hists.view(-1, self.bin_number, 3).transpose(1, 2)
        emb = self.linearHist(histExpand)
        emb[~filtersMask.bool()] = 0.

        num_filters = torch.sum(filtersMask, dim=1)
        total = torch.sum(emb, dim=1)
        avg = total / num_filters.view(-1, 1)
        return avg

    def getFilter(self, filtersId, filtersMask):
        # Reshape into up to 20 (column id, operator id) filter slots per node.
        filterExpand = filtersId.view(-1, 2, 20).transpose(1, 2)
        colsId = filterExpand[:, :, 0].long()
        opsId = filterExpand[:, :, 1].long()

        col = self.columnEmbed(colsId)
        op = self.opEmbed(opsId)

        concat = torch.cat((col, op), dim=-1)
        concat = F.leaky_relu(self.linearFilter(concat))
        concat = F.leaky_relu(self.linearFilter2(concat))

        # Zero out padded filter slots, then average over the used ones.
        concat[~filtersMask.bool()] = 0.
        num_filters = torch.sum(filtersMask, dim=1)
        total = torch.sum(concat, dim=1)
        avg = total / num_filters.view(-1, 1)
        return avg
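# Illustrative layout note (added; inferred from the torch.split and getFilter
# logic above, not an original comment). Each node feature vector fed to
# FeatureEmbed.forward is 1 + 1 + 40 + 20 + 1001 + 4 = 1067 wide:
#   [0]          operator type id
#   [1]          join id
#   [2:22]       column ids for up to 20 filter predicates
#   [22:42]      operator ids for the same filter slots
#   [42:62]      filter mask (1.0 for used slots, 0.0 for padding)
#   [62:1063]    table id followed by a 1000-dim sample bitmap
#   [1063:1067]  cost features
# With the defaults (embed_size=32, histograms disabled) each node embeds to
# 32 * 4 + 32 // 8 + 4 = 136 dimensions.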
class QueryFormer(nn.Module):
    def __init__(self, emb_size=32, ffn_dim=32, head_size=8,
                 dropout=0.1, attention_dropout_rate=0.1, n_layers=8,
                 use_sample=True, use_hist=True, bin_number=50,
                 pred_hid=256, input_size=1067):
        super(QueryFormer, self).__init__()

        use_hist = False
        if use_hist:
            hidden_dim = emb_size * 5 + emb_size // 8 + 1 + 4
        else:
            hidden_dim = emb_size * 4 + emb_size // 8 + 4
        self.hidden_dim = hidden_dim
        self.head_size = head_size
        self.use_sample = use_sample
        self.use_hist = use_hist
        self.input_size = input_size

        self.rel_pos_encoder = nn.Embedding(64, head_size, padding_idx=0)
        self.height_encoder = nn.Embedding(64, hidden_dim, padding_idx=0)

        self.input_dropout = nn.Dropout(dropout)
        encoders = [EncoderLayer(hidden_dim, ffn_dim, dropout, attention_dropout_rate, head_size)
                    for _ in range(n_layers)]
        self.layers = nn.ModuleList(encoders)

        self.final_ln = nn.LayerNorm(hidden_dim)

        self.super_token = nn.Embedding(1, hidden_dim)
        self.super_token_virtual_distance = nn.Embedding(1, head_size)

        self.embbed_layer = FeatureEmbed(emb_size, use_sample=use_sample,
                                         use_hist=use_hist, bin_number=bin_number)

        self.pred = Prediction(hidden_dim, pred_hid)
        self.pred_ln = nn.LayerNorm(pred_hid)
        self.pred2 = Prediction(hidden_dim, pred_hid)

    def forward(self, batched_data):
        attn_bias, rel_pos, x = batched_data["attn_bias"], batched_data["rel_pos"], batched_data["x"]
        heights = batched_data["heights"]

        n_batch, n_node = x.size()[:2]
        tree_attn_bias = attn_bias.clone()
        tree_attn_bias = tree_attn_bias.unsqueeze(1).repeat(1, self.head_size, 1, 1)

        # Relative-position (tree distance) bias for node-to-node attention.
        rel_pos_bias = self.rel_pos_encoder(rel_pos).permute(0, 3, 1, 2)
        tree_attn_bias[:, :, 1:, 1:] = tree_attn_bias[:, :, 1:, 1:] + rel_pos_bias

        # Learned virtual distance between the super token and every plan node.
        t = self.super_token_virtual_distance.weight.view(1, self.head_size, 1)
        tree_attn_bias[:, :, 1:, 0] = tree_attn_bias[:, :, 1:, 0] + t
        tree_attn_bias[:, :, 0, :] = tree_attn_bias[:, :, 0, :] + t

        x_view = x.view(-1, self.input_size)
        node_feature = self.embbed_layer(x_view).view(n_batch, -1, self.hidden_dim)
        node_feature = node_feature + self.height_encoder(heights)

        # Prepend the learnable super token that summarizes the whole plan.
        super_token_feature = self.super_token.weight.unsqueeze(0).repeat(n_batch, 1, 1)
        super_node_feature = torch.cat([super_token_feature, node_feature], dim=1)

        output = self.input_dropout(super_node_feature)
        for enc_layer in self.layers:
            output = enc_layer(output, tree_attn_bias)
        output = self.final_ln(output)

        output = self.pred(output)
        output = self.pred_ln(output)
        return output
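# Illustrative shape note (added; inferred from QueryFormer.forward above).
# For a plan with n_node operators, batched_data is expected to hold
#   "x":         FloatTensor [n_batch, n_node, input_size]      per-node features
#   "attn_bias": FloatTensor [n_batch, n_node + 1, n_node + 1]  slot 0 is the super token
#   "rel_pos":   LongTensor  [n_batch, n_node, n_node]          tree distances, values < 64
#   "heights":   LongTensor  [n_batch, n_node]                  node heights, values < 64
# and the model returns [n_batch, n_node + 1, pred_hid]; index 0 on dim 1 is the
# super-token summary of the whole plan.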
class QueryFormerBert(nn.Module):
    def __init__(self, emb_size=32, ffn_dim=32, head_size=8,
                 dropout=0.1, attention_dropout_rate=0.1, n_layers=8,
                 use_sample=True, use_hist=True, bin_number=50,
                 pred_hid=256, input_size=1067):
        super(QueryFormerBert, self).__init__()

        use_hist = False
        # Node features are expected to arrive as 768-dim embeddings (e.g. from a
        # BERT encoder), so the hidden size is fixed rather than derived from emb_size.
        hidden_dim = 768
        self.hidden_dim = hidden_dim
        self.head_size = head_size
        self.use_sample = use_sample
        self.use_hist = use_hist
        self.input_size = input_size

        self.rel_pos_encoder = nn.Embedding(64, head_size, padding_idx=0)
        self.height_encoder = nn.Embedding(64, hidden_dim, padding_idx=0)

        self.input_dropout = nn.Dropout(dropout)
        encoders = [EncoderLayer(hidden_dim, ffn_dim, dropout, attention_dropout_rate, head_size)
                    for _ in range(n_layers)]
        self.layers = nn.ModuleList(encoders)

        self.final_ln = nn.LayerNorm(hidden_dim)

        self.super_token = nn.Embedding(1, hidden_dim)
        self.super_token_virtual_distance = nn.Embedding(1, head_size)

        # Kept for interface parity with QueryFormer; not used in forward below.
        self.embbed_layer = FeatureEmbed(emb_size, use_sample=use_sample,
                                         use_hist=use_hist, bin_number=bin_number)

        self.pred = Prediction(hidden_dim, pred_hid)
        self.pred_ln = nn.LayerNorm(pred_hid)
        self.pred2 = Prediction(hidden_dim, pred_hid)

    def forward(self, batched_data):
        attn_bias, rel_pos, x = batched_data["attn_bias"], batched_data["rel_pos"], batched_data["x"]
        heights = batched_data["heights"]

        n_batch, n_node = x.size()[:2]
        tree_attn_bias = attn_bias.clone()
        tree_attn_bias = tree_attn_bias.unsqueeze(1).repeat(1, self.head_size, 1, 1)

        rel_pos_bias = self.rel_pos_encoder(rel_pos).permute(0, 3, 1, 2)
        tree_attn_bias[:, :, 1:, 1:] = tree_attn_bias[:, :, 1:, 1:] + rel_pos_bias

        t = self.super_token_virtual_distance.weight.view(1, self.head_size, 1)
        tree_attn_bias[:, :, 1:, 0] = tree_attn_bias[:, :, 1:, 0] + t
        tree_attn_bias[:, :, 0, :] = tree_attn_bias[:, :, 0, :] + t

        # Node features are used as-is instead of going through FeatureEmbed.
        node_feature = x
        node_feature = node_feature + self.height_encoder(heights)

        super_token_feature = self.super_token.weight.unsqueeze(0).repeat(n_batch, 1, 1)
        super_node_feature = torch.cat([super_token_feature, node_feature], dim=1)

        output = self.input_dropout(super_node_feature)
        for enc_layer in self.layers:
            output = enc_layer(output, tree_attn_bias)
        output = self.final_ln(output)

        output = self.pred(output)
        output = self.pred_ln(output)
        return output


class FeedForwardNetwork(nn.Module):
    def __init__(self, hidden_size, ffn_size, dropout_rate):
        super(FeedForwardNetwork, self).__init__()

        self.layer1 = nn.Linear(hidden_size, ffn_size)
        self.gelu = nn.GELU()
        self.layer2 = nn.Linear(ffn_size, hidden_size)

    def forward(self, x):
        x = self.layer1(x)
        x = self.gelu(x)
        x = self.layer2(x)
        return x


class MultiHeadAttention(nn.Module):
    def __init__(self, hidden_size, attention_dropout_rate, head_size):
        super(MultiHeadAttention, self).__init__()

        self.head_size = head_size
        self.att_size = att_size = hidden_size // head_size
        self.scale = att_size ** -0.5

        self.linear_q = nn.Linear(hidden_size, head_size * att_size)
        self.linear_k = nn.Linear(hidden_size, head_size * att_size)
        self.linear_v = nn.Linear(hidden_size, head_size * att_size)
        self.att_dropout = nn.Dropout(attention_dropout_rate)

        self.output_layer = nn.Linear(head_size * att_size, hidden_size)

    def forward(self, q, k, v, attn_bias=None):
        orig_q_size = q.size()

        d_k = self.att_size
        d_v = self.att_size
        batch_size = q.size(0)

        q = self.linear_q(q).view(batch_size, -1, self.head_size, d_k)
        k = self.linear_k(k).view(batch_size, -1, self.head_size, d_k)
        v = self.linear_v(v).view(batch_size, -1, self.head_size, d_v)

        q = q.transpose(1, 2)
        v = v.transpose(1, 2)
        k = k.transpose(1, 2).transpose(2, 3)

        # Scaled dot-product attention with an additive bias on the scores.
        q = q * self.scale
        x = torch.matmul(q, k)
        if attn_bias is not None:
            x = x + attn_bias

        x = torch.softmax(x, dim=3)
        x = self.att_dropout(x)
        x = x.matmul(v)

        x = x.transpose(1, 2).contiguous()
        x = x.view(batch_size, -1, self.head_size * d_v)

        x = self.output_layer(x)

        assert x.size() == orig_q_size
        return x
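# Illustrative note (added; inferred from the code above, not an original comment):
# q, k, v are [batch, seq_len, hidden_size]; attn_bias, when provided, is added to
# the pre-softmax attention scores and must broadcast to
# [batch, head_size, seq_len, seq_len], which is the shape of the tree_attn_bias
# tensor built in QueryFormer.forward.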
class EncoderLayer(nn.Module):
    def __init__(self, hidden_size, ffn_size, dropout_rate, attention_dropout_rate, head_size):
        super(EncoderLayer, self).__init__()

        self.self_attention_norm = nn.LayerNorm(hidden_size)
        self.self_attention = MultiHeadAttention(hidden_size, attention_dropout_rate, head_size)
        self.self_attention_dropout = nn.Dropout(dropout_rate)

        self.ffn_norm = nn.LayerNorm(hidden_size)
        self.ffn = FeedForwardNetwork(hidden_size, ffn_size, dropout_rate)
        self.ffn_dropout = nn.Dropout(dropout_rate)

    def forward(self, x, attn_bias=None):
        y = self.self_attention_norm(x)
        y = self.self_attention(y, y, y, attn_bias)
        y = self.self_attention_dropout(y)
        x = x + y

        y = self.ffn_norm(x)
        y = self.ffn(y)
        y = self.ffn_dropout(y)
        x = x + y
        return x
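

# Minimal smoke-test sketch (added for illustration; dummy shapes and random
# values, not part of the original training pipeline).
if __name__ == "__main__":
    torch.manual_seed(0)
    n_batch, n_node = 2, 5

    x = torch.zeros(n_batch, n_node, 1067)
    # Mark one filter slot as used per node so getFilter does not divide by zero.
    x[:, :, 42] = 1.0

    batched_data = {
        "x": x,
        "attn_bias": torch.zeros(n_batch, n_node + 1, n_node + 1),
        "rel_pos": torch.randint(0, 10, (n_batch, n_node, n_node)),
        "heights": torch.randint(0, 10, (n_batch, n_node)),
    }

    model = QueryFormer()
    out = model(batched_data)
    print(out.shape)  # expected: torch.Size([2, 6, 256])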