# aiops/RCRank/pretrain/alignment_new/pretrain.py
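"""Alignment pretraining for RCRank.

Trains the Alignment model on two interleaved masked-modality objectives:
'mask_plan' reconstructs a masked query-plan embedding and 'mask_sql' a
masked SQL embedding, each from the remaining modalities (plan, SQL text,
runtime log, and a BERT-encoded table dictionary), under an MSE loss.
"""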
import json
import pickle
import sys

import torch
import torch.nn as nn
import torch.optim as optim
from transformers import BertModel, BertTokenizer

sys.path.append("../..")
from model.modules.LogModel.log_model import LogModel
from model.modules.QueryFormer.utils import *
from model.modules.QueryFormer.QueryFormer import QueryFormer

class Predict(nn.Module):
    """Projects the fused multi-modal vector into the encoder width and runs
    it through a Transformer encoder; the task head is applied by the caller."""

    def __init__(self, input_dim, model_dim, num_heads, ff_dim, dropout=0.1):
        super(Predict, self).__init__()
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=model_dim, nhead=num_heads, dim_feedforward=ff_dim,
            dropout=dropout, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=6)
        # Two-stage projection: input_dim -> 1100 -> model_dim.
        self.input_fc = nn.Linear(input_dim, 1100, bias=False)
        self.input_fc1 = nn.Linear(1100, model_dim, bias=False)

    def forward(self, x):
        x = self.input_fc(x)
        x = self.input_fc1(x)
        x = self.transformer_encoder(x)
        return x

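# Shape sketch, derived from the layer sizes above and in Alignment below:
# the fused feature (batch, 1, 17600) is projected 17600 -> 1100 -> model_dim
# (1024 for the mask_plan head, 768 for mask_sql), encoded by the Transformer,
# and the single token is then mapped to the 1067-d plan / 768-d SQL target.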
class Alignment(nn.Module):
    """Encodes plan, SQL, log, and table-dictionary inputs, fuses them, and
    predicts the embedding of the masked modality ('mask_plan' or 'mask_sql')."""

    def __init__(self, device):
        super(Alignment, self).__init__()
        self.flatten = nn.Flatten()
        self.plan_model = QueryFormer(pred_hid=32)
        self.sql_model = BertModel.from_pretrained("./bert-base-uncased")
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=768, nhead=8, dim_feedforward=2048, dropout=0.1, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=2)
        self.tokenizer = BertTokenizer.from_pretrained("./bert-base-uncased")
        self.log_model = LogModel(input_dim=13, hidden_dim=64, output_dim=32)
        # Width of the concatenated (plan, sql, log, dic) feature vector.
        self.concat_dim_mask_plan = 17600
        self.predict_mask_plan = Predict(input_dim=self.concat_dim_mask_plan,
                                         model_dim=1024, num_heads=8, ff_dim=2048)
        self.Linear_mask_plan = nn.Linear(1024, 1067)
        self.concat_dim_mask_sql = 17600
        self.predict_mask_sql = Predict(input_dim=self.concat_dim_mask_sql,
                                        model_dim=768, num_heads=8, ff_dim=2048)
        self.Linear_mask_sql = nn.Linear(768, 768)
        self.device = device
    def forward(self, plan, sql, log, dic, mod):
        plan = self.plan_model(plan)
        sql = self.tokenizer(sql, return_tensors="pt", padding=True,
                             truncation=True, max_length=512).to(self.device)
        sql = self.sql_model(**sql).last_hidden_state
        # Mean-pool the token sequence into one 768-d SQL embedding.
        sql = self.transformer_encoder(sql).mean(dim=1)
        log = self.log_model(log)
        plan = self.flatten(plan)
        # (batch, 1, concat_dim): a one-token "sequence" for the Predict encoder.
        concatenated_vector = torch.cat((plan, sql, log, dic), dim=1).unsqueeze(1)
        if mod == 'mask_plan':
            transformer_output = self.predict_mask_plan(concatenated_vector)
            transformer_output = transformer_output[:, 0, :]
            predicted_vector = self.Linear_mask_plan(transformer_output)
        elif mod == 'mask_sql':
            transformer_output = self.predict_mask_sql(concatenated_vector)
            transformer_output = transformer_output[:, 0, :]
            predicted_vector = self.Linear_mask_sql(transformer_output)
        else:
            raise ValueError(f"unknown pretraining objective: {mod}")
        return predicted_vector

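# A minimal usage sketch (hypothetical inputs; real batches are assembled in
# the training loop below):
#   model = Alignment(device).to(device)
#   out = model(plan_dict, ["select ..."], log_tensor, dic_tensor, 'mask_sql')
#   # out: (batch, 768) prediction of the masked SQL embedding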
if __name__ == "__main__":
    # idx2table is an index-to-table mapping; dump it to a JSON string so BERT
    # can encode the dictionary as plain text.
    with open('mask_table_plan/encoding_1w.pickle', 'rb') as f1:
        encoding_mask_plan = json.dumps(pickle.load(f1).idx2table)
    with open('mask_table_sql/encoding1w.pickle', 'rb') as f2:
        encoding_mask_sql = json.dumps(pickle.load(f2).idx2table)
    device = 'cuda:2'
    tokenizer = BertTokenizer.from_pretrained("./bert-base-uncased")
    bert = BertModel.from_pretrained("./bert-base-uncased").to(device)
    encoding_mask_plan = tokenizer(encoding_mask_plan, return_tensors="pt",
                                   padding=True, truncation=True, max_length=512).to(device)
    encoding_mask_sql = tokenizer(encoding_mask_sql, return_tensors="pt",
                                  padding=True, truncation=True, max_length=512).to(device)
    encoding = {'mask_plan': encoding_mask_plan, 'mask_sql': encoding_mask_sql}
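    # In the loop below, bert(**encoding[mod]).pooler_output turns each
    # pre-tokenized dictionary string into a single 768-d context vector.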
    model = Alignment(device).to(device)
    # Resume from the checkpoint saved after the 6th epoch (see the epoch <= 6 skip below).
    model.load_state_dict(torch.load('./model6.pth', map_location=device))
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)
    epochs = 500
    batch_size = 18
    batch_loss = {'mask_sql': 0.0, 'mask_plan': 0.0}
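    # Each pass through the inner mod loop trains one 'mask_sql' batch and one
    # 'mask_plan' batch, interleaving the two objectives; the cursor i advances
    # by batch_size only on the 'mask_sql' leg, keeping both corpora in lockstep.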
    for epoch in range(epochs):
        # Epochs 0-6 are skipped because training resumes from model6.pth.
        if epoch <= 6:
            print(epoch)
        else:
            i = 0
            size = 0
            model.train()
            epoch_loss = 0.0
            # Re-open the two pretraining corpora (~10110 samples each) at the
            # start of every epoch.
            file = {}
            file['mask_plan'] = open('mask_table_plan/plan_mask_1w.txt', 'r')
            file['mask_sql'] = open('mask_table_sql/sql_mask1w.txt', 'r')
            batch_list = {}
            while i <= 10110:
                for mod in ['mask_sql', 'mask_plan']:
                    # Accumulate up to batch_size JSON-encoded samples.
                    for _ in range(batch_size):
                        line = file[mod].readline()
                        line = line.strip()
                        if not line:
                            break
                        (query, x1, attn_bias, rel_pos, height,
                         log_all, predict) = json.loads(line)
                        size = size + 1
                        if len(batch_list.keys()) == 0:
                            # First sample initializes the batch tensors.
                            batch_list['query'] = [query]
                            plan = {}
                            plan['x'] = torch.tensor(x1).to(torch.float32).to(device)
                            plan["attn_bias"] = torch.tensor(attn_bias).to(device)
                            plan["rel_pos"] = torch.tensor(rel_pos).to(device)
                            plan["heights"] = torch.tensor(height).to(device)
                            batch_list['plan'] = plan
                            batch_list['log'] = torch.tensor(log_all).unsqueeze(0).to(device)
                            batch_list['predict'] = torch.tensor(predict).to(device).unsqueeze(0)
                        else:
                            batch_list['query'].append(query)
                            batch_list['plan']['x'] = torch.cat(
                                (batch_list['plan']['x'], torch.tensor(x1).to(torch.float32).to(device)), dim=0)
                            batch_list['plan']["attn_bias"] = torch.cat(
                                [batch_list['plan']['attn_bias'], torch.tensor(attn_bias).to(device)], dim=0)
                            batch_list['plan']["rel_pos"] = torch.cat(
                                [batch_list['plan']['rel_pos'], torch.tensor(rel_pos).to(device)], dim=0)
                            batch_list['plan']["heights"] = torch.cat(
                                [batch_list['plan']['heights'], torch.tensor(height).to(device)], dim=0)
                            batch_list['log'] = torch.cat(
                                [batch_list['log'], torch.tensor(log_all).unsqueeze(0).to(device)], dim=0)
                            batch_list['predict'] = torch.cat(
                                [batch_list['predict'], torch.tensor(predict).unsqueeze(0).to(device)], dim=0)
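                    # The per-line plan tensors are assumed to already carry a
                    # leading batch dim of 1, so the torch.cat(dim=0) calls above
                    # stack samples into (batch, ...) tensors; log and predict
                    # get theirs via unsqueeze(0).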
                    if not batch_list:
                        # File exhausted before filling a batch; skip training.
                        continue
                    optimizer.zero_grad()
                    sql, plan, log, predict = (batch_list["query"], batch_list["plan"],
                                               batch_list["log"], batch_list['predict'])
                    # The dictionary encoder is effectively frozen (its parameters
                    # are not in the optimizer), so skip gradient tracking here.
                    with torch.no_grad():
                        dic = bert(**encoding[mod]).pooler_output
                    dic = dic.repeat(len(sql), 1)
                    output = model(plan, sql, log, dic, mod)
                    loss = criterion(output, predict)
                    loss.backward()
                    optimizer.step()
                    batch_loss[mod] = loss.item()
                    epoch_loss += loss.item()
                    batch_list = {}
                    # Advance the sample cursor once per (sql, plan) pair.
                    if mod == 'mask_sql':
                        i = i + batch_size
            for f in file.values():
                f.close()
            print(f"epoch {epoch}: size={size}, epoch_loss={epoch_loss:.4f}, "
                  f"last batch losses={batch_loss}")
        if (epoch + 1) % 3 == 0:
            torch.save(model.state_dict(), f'model{epoch + 1}.pth')
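    # Checkpoints are written to the working directory as model{epoch+1}.pth
    # every third epoch; model6.pth loaded above presumably came from an
    # earlier run of this script.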