submission_code/best_ctx.py (34 lines of code) (raw):
import pandas as pd
import numpy as np
import joblib
import torch
import torch.nn as nn
import torch.utils.data as data
import sentencepiece as spm
from tools import beam_search, Transformer
# Special token ids used throughout this module: pad_id is handed to the
# Transformer and to beam_search as the padding index, while bos_id / eos_id
# wrap every encoded source sequence in BestCtxModel.predict and are also
# passed to beam_search.
pad_id, bos_id, eos_id = 0, 1, 2
class BestCtxModel:
    """Inference wrapper that turns a text query plus per-utility context
    into candidate commands via beam search over a Transformer."""

    def __init__(self, config, file_path, model_path, device):
        """Load the tokenizers, the model weights and the context table.

        Args:
            config: configuration object. Read here: ``joined_vocab``; read in
                ``predict``: ``max_src_len`` and ``max_tgt_len``. It is also
                passed to ``Transformer``.
            file_path: directory holding the SentencePiece model file(s) and
                the ``man_processed`` joblib dump.
            model_path: checkpoint file containing a ``'model_state_dict'``
                entry.
            device: torch device (or device string) the model runs on.
        """
        self.config = config
        self.device = device
        if config.joined_vocab:
            # One shared vocabulary serves both the source text and commands.
            self.text_tokenizer = spm.SentencePieceProcessor(f'{file_path}/all_bpe_ctx.model')
            self.cmd_tokenizer = self.text_tokenizer
        else:
            self.text_tokenizer = spm.SentencePieceProcessor(f'{file_path}/txt_bpe_ctx.model')
            self.cmd_tokenizer = spm.SentencePieceProcessor(f'{file_path}/cmd_bpe_ctx.model')

        self.model = Transformer(self.config, pad_id)
        # NOTE(review): torch.load and joblib.load unpickle their input; only
        # point them at trusted files.
        checkpoint = torch.load(model_path, map_location=device)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()
        self.model.to(self.device)
        # Only ``.get(util, '')`` is used on this object; presumably a dict
        # from utility name to processed context text — TODO confirm.
        self.ctx = joblib.load(f'{file_path}/man_processed')

    def predict(self, text, util, beam_width):
        """Generate candidate commands for ``text``.

        The context stored for ``util`` (empty string if absent) is appended
        to ``text`` after a space. The encoded ids are truncated to
        ``config.max_src_len`` and wrapped in ``bos_id`` / ``eos_id``, then
        decoded with beam search (``k=beam_width``, ``max_len=config.max_tgt_len``).

        Returns:
            A list of ``(decoded_command, proba)`` pairs, one per hypothesis
            returned by ``beam_search``, in the same order.
        """
        text = text + ' ' + self.ctx.get(util, '')
        text_enc = self.text_tokenizer.encode(text)
        # Build the source directly on the model's device. beam_search
        # receives no device argument, so a CPU tensor would otherwise reach
        # a model that may live on another device.
        tokens = torch.tensor(
            [bos_id] + text_enc[:self.config.max_src_len] + [eos_id],
            dtype=torch.long,
            device=self.device,
        )
        with torch.no_grad():
            hyps = beam_search(tokens, self.model.tr, pad_id, bos_id, eos_id,
                               max_len=self.config.max_tgt_len, k=beam_width)
        # Hypothesis ids may be tensor elements; SentencePiece decode expects
        # plain Python ints.
        return [
            (self.cmd_tokenizer.decode([int(tok) for tok in ids]), proba)
            for ids, proba in hyps
        ]