in aiops/RCRank/model/modules/rcrank_model.py [0:0]
def __init__(self, t_input_dim, l_input_dim, l_hidden_dim, t_hidden_him, emb_dim, sql_model=None, device=None,
             plan_args=None, cross_model=None, time_model=None, cross_mean=True, cross_model_real=None,
             cross_model_CrossSQLPlan=None, rootcause_cross_model=None,
             alignment_ckpt_path='./pretrain/alignment_new/model30.pth') -> None:
    """Build the RCRank sub-modules and load the pretrained alignment encoders.

    Constructs projection layers, five label/option prediction heads,
    batch-norm layers and five sets of per-modality gates (SQL, plan, log,
    metrics), calls ``self.init_params()``, then loads an ``Alignment``
    checkpoint and reuses its ``log_model`` and ``plan_model`` as this
    model's encoders.

    Args:
        t_input_dim: Not referenced in this constructor.
        l_input_dim: Not referenced in this constructor.
        l_hidden_dim: Hidden size passed to the initial ``LogModel``.
        t_hidden_him: Not referenced in this constructor.
        emb_dim: Common embedding width used by the projections, prediction
            heads, gates and batch-norm layers.
        sql_model: SQL encoder, stored as ``self.sql_model``.
        device: Passed to ``Alignment`` and stored as ``self.device``.
        plan_args: Not referenced in this constructor.
        cross_model: Stored as ``self.common_cross_model``.
        time_model: Stored as ``self.time_model``.
        cross_mean: Not referenced in this constructor.
        cross_model_real: Not referenced in this constructor.
        cross_model_CrossSQLPlan: Not referenced in this constructor.
        rootcause_cross_model: Stored as ``self.rootcause_cross_model``.
        alignment_ckpt_path: Path of the ``Alignment`` state dict to load.
            The default is the previously hard-coded location, which is
            resolved relative to the current working directory.
    """
    super().__init__()
    self.time_model = time_model
    # NOTE(review): this LogModel is replaced by the pretrained alignment log
    # encoder at the end of __init__. It is still constructed so module
    # creation order (and RNG draws under a fixed seed) stays unchanged, and
    # because init_params() is invoked while it is attached.
    self.log_model = LogModel(1, l_hidden_dim, emb_dim)
    # 768 is presumably the output width of the SQL / plan encoders - confirm.
    self.sql_last_emb = nn.Linear(768, emb_dim)
    self.plan_last_emb = nn.Linear(768, emb_dim)
    # Input width is 7 * emb_dim (presumably 7 concatenated per-metric
    # embeddings, matching metrics_bn1's 7 channels - confirm).
    self.time_tran_emb = nn.Linear(emb_dim * 7, emb_dim)
    self.common_cross_model = cross_model
    self.rootcause_cross_model = rootcause_cross_model

    # Five scalar-output heads each for labels and options (presumably one per
    # root-cause type - confirm). Appends stay interleaved so parameters are
    # created in the original order.
    self.pred_label_cross_list = nn.ModuleList()
    self.pred_opt_cross_list = nn.ModuleList()
    for _ in range(5):
        self.pred_label_cross_list.append(nn.Linear(emb_dim, 1))
        self.pred_opt_cross_list.append(nn.Linear(emb_dim, 1))

    self.log_bn = nn.BatchNorm1d(emb_dim)
    self.metrics_bn1 = nn.BatchNorm1d(7)
    self.metrics_bn2 = nn.BatchNorm1d(emb_dim)

    self.gate_sql = nn.ModuleList()
    self.gate_sql_activate = nn.ModuleList()
    self.gate_plan = nn.ModuleList()
    self.gate_plan_activate = nn.ModuleList()
    self.gate_log = nn.ModuleList()
    self.gate_log_activate = nn.ModuleList()
    self.gate_metrics = nn.ModuleList()
    self.gate_metrics_activate = nn.ModuleList()
    # NOTE(review): never populated in this constructor; kept as registered.
    self.gate_metrics_norm = nn.ModuleList()
    self.gate_out_dim = 1

    def _make_gate(name, out_features):
        # One Linear wrapped in a Sequential under `name`; the inner name is
        # part of the state_dict key (e.g. "gate_sql.0.gate_sql.weight"), so it
        # must stay stable for checkpoint compatibility.
        gate = nn.Sequential()
        gate.add_module(name, nn.Linear(in_features=emb_dim, out_features=out_features))
        return gate

    for _ in range(5):
        self.gate_sql.append(_make_gate('gate_sql', self.gate_out_dim))
        self.gate_sql_activate.append(nn.Sigmoid())
        self.gate_plan.append(_make_gate('gate_plan', self.gate_out_dim))
        self.gate_plan_activate.append(nn.Sigmoid())
        # The log gate outputs emb_dim values (presumably a per-feature gate),
        # unlike the other gates which output gate_out_dim values.
        self.gate_log.append(_make_gate('gate_log', emb_dim))
        self.gate_log_activate.append(nn.Sigmoid())
        self.gate_metrics.append(_make_gate('gate_metrics', self.gate_out_dim))
        self.gate_metrics_activate.append(nn.Sigmoid())

    # Runs before the alignment model and sql_model are attached.
    self.init_params()
    self.device = device
    self.alignmentModel = Alignment(device=device)
    # map_location='cpu' lets a checkpoint saved from GPU load on a CPU-only
    # host; load_state_dict then copies the values onto each parameter's own
    # device, so the loaded weights are the same as before.
    state_dict = torch.load(alignment_ckpt_path, map_location='cpu')
    self.alignmentModel.load_state_dict(state_dict)  # now use
    self.sql_model = sql_model
    # Reuse the pretrained alignment encoders (replaces the LogModel above).
    self.log_model = self.alignmentModel.log_model
    self.plan_model = self.alignmentModel.plan_model