aiops/RCRank/model/modules/single_model.py (146 lines of code) (raw):
import torch
from torch import nn
from model.modules.QueryFormer.QueryFormer import QueryFormer
from model.modules.LogModel.log_model import LogModel
# ----------------------------------------- 单模态 -----------------------------------
class SQLOptModel(nn.Module):
def __init__(self, t_input_dim, l_input_dim, l_hidden_dim, t_hidden_him, emb_dim, sql_model=None, device=None,
plan_args=None, cross_model=None, time_model=None, cross_mean=True) -> None:
super().__init__()
self.plan_model = QueryFormer(emb_size = plan_args.embed_size ,ffn_dim = plan_args.ffn_dim, head_size = plan_args.head_size, \
dropout = plan_args.dropout, n_layers = plan_args.n_layers, \
use_sample = plan_args.use_sample, use_hist = False, \
pred_hid = emb_dim)
self.sql_last_emb = nn.Linear(768, emb_dim)
self.activation = nn.ReLU()
self.pred_label_concat = nn.Linear(emb_dim * 4, 5)
self.pred_label_cross = nn.Linear(emb_dim, 5)
self.pred_opt_concat = nn.Linear(emb_dim * 4, 5)
self.pred_opt_cross = nn.Linear(emb_dim, 5)
self.init_params()
self.sql_model = sql_model
def init_params(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, sql, plan, time, log):
with torch.no_grad():
sql_emb = self.sql_model(**sql)
sql_emb = sql_emb.last_hidden_state
sql_emb = sql_emb[:, 0, :]
sql_emb = self.sql_last_emb(sql_emb)
sql_emb = self.activation(sql_emb)
pred_label = self.pred_label_cross(sql_emb)
pred_opt = self.pred_opt_cross(sql_emb)
pred_label = torch.sigmoid(pred_label)
return pred_label, pred_opt
class PlanOptModel(nn.Module):
def __init__(self, t_input_dim, l_input_dim, l_hidden_dim, t_hidden_him, emb_dim, sql_model=None, device=None,
plan_args=None, cross_model=None, time_model=None, cross_mean=True) -> None:
super().__init__()
self.plan_model = QueryFormer(emb_size = plan_args.embed_size ,ffn_dim = plan_args.ffn_dim, head_size = plan_args.head_size, \
dropout = plan_args.dropout, n_layers = plan_args.n_layers, \
use_sample = plan_args.use_sample, use_hist = False, \
pred_hid = emb_dim)
self.pred_label_cross = nn.Linear(emb_dim, 5)
self.pred_opt_cross = nn.Linear(emb_dim, 5)
self.activation = nn.ReLU()
self.init_params()
def init_params(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, sql, plan, time, log):
plan_emb = self.plan_model(plan)
plan_emb = plan_emb[:, 0, :]
plan_emb = self.activation(plan_emb)
pred_label = self.pred_label_cross(plan_emb)
pred_opt = self.pred_opt_cross(plan_emb)
pred_label = torch.sigmoid(pred_label)
return pred_label, pred_opt
class LogOptModel(nn.Module):
def __init__(self, t_input_dim, l_input_dim, l_hidden_dim, t_hidden_him, emb_dim, sql_model=None, device=None,
plan_args=None, cross_model=None, time_model=None, cross_mean=True) -> None:
super().__init__()
self.log_model = LogModel(l_input_dim, l_hidden_dim, emb_dim)
self.pred_label_cross = nn.Linear(emb_dim, 5)
self.pred_opt_cross = nn.Linear(emb_dim, 5)
self.activation = nn.ReLU()
self.init_params()
self.sql_model = sql_model
def init_params(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, sql, plan, time, log):
log_emb = self.log_model(log)
log_emb = self.activation(log_emb)
pred_label = self.pred_label_cross(log_emb)
pred_opt = self.pred_opt_cross(log_emb)
pred_label = torch.sigmoid(pred_label)
return pred_label, pred_opt
class TimeOptModel(nn.Module):
def __init__(self, t_input_dim, l_input_dim, l_hidden_dim, t_hidden_him, emb_dim, sql_model=None, device=None,
plan_args=None, cross_model=None, time_model=None, cross_mean=True) -> None:
super().__init__()
self.time_model = time_model
self.pred_label_cross = nn.Linear(emb_dim, 5)
self.pred_opt_cross = nn.Linear(emb_dim, 5)
self.time_tran_emb = nn.Linear(emb_dim, emb_dim)
self.activation = nn.ReLU()
self.init_params()
self.sql_model = sql_model
def init_params(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, sql, plan, time, log):
time_emb = time.unsqueeze(1)
time_emb = self.time_model(time_emb)
time_emb = torch.flatten(time_emb, start_dim=1)
time_emb = self.time_tran_emb(time_emb)
time_emb = self.activation(time_emb)
pred_label = self.pred_label_cross(time_emb)
pred_opt = self.pred_opt_cross(time_emb)
pred_label = torch.sigmoid(pred_label)
return pred_label, pred_opt
# ----------------------------------------- 单模态 -----------------------------------
# Concat
class ConcatOptModel(nn.Module):
def __init__(self, t_input_dim, l_input_dim, l_hidden_dim, t_hidden_him, emb_dim, sql_model=None, device=None,
plan_args=None, cross_model=None, time_model=None, cross_mean=True) -> None:
super().__init__()
self.plan_model = QueryFormer(emb_size = plan_args.embed_size ,ffn_dim = plan_args.ffn_dim, head_size = plan_args.head_size, \
dropout = plan_args.dropout, n_layers = plan_args.n_layers, \
use_sample = plan_args.use_sample, use_hist = False, \
pred_hid = emb_dim)
self.time_model = time_model
self.log_model = LogModel(l_input_dim, l_hidden_dim, emb_dim)
self.sql_last_emb = nn.Linear(768, emb_dim)
self.time_tran_emb = nn.Linear(emb_dim, emb_dim)
self.pred_label_concat = nn.Linear(emb_dim * 4, 5)
self.pred_label_cross = nn.Linear(emb_dim, 5)
self.pred_opt_concat = nn.Linear(emb_dim * 4, 5)
self.pred_opt_cross = nn.Linear(emb_dim, 5)
self.cross_mean = cross_mean
self.init_params()
self.sql_model = sql_model
def init_params(self):
for p in self.parameters():
if p.dim() > 1:
nn.init.xavier_uniform_(p)
def forward(self, sql, plan, time, log):
with torch.no_grad():
sql_emb = self.sql_model(**sql)
sql_emb = sql_emb.last_hidden_state
plan_emb = self.plan_model(plan)
log_emb = self.log_model(log)
time_emb = time.unsqueeze(1)
time_emb = self.time_model(time_emb)
sql_emb = sql_emb[:, 0, :]
sql_emb = self.sql_last_emb(sql_emb)
plan_emb = plan_emb[:, 0, :]
time_emb = torch.flatten(time_emb, start_dim=1)
time_emb = self.time_tran_emb(time_emb)
emb = torch.cat([sql_emb, plan_emb, log_emb, time_emb], dim=1)
pred_label = self.pred_label_concat(emb)
pred_opt = self.pred_opt_concat(emb)
pred_label = torch.sigmoid(pred_label)
return pred_label, pred_opt