# submission_code/infer.py (41 lines of code) (raw):
import sentencepiece as spm
import torch
import numpy as np
from dataclasses import dataclass
from copy import copy
from preprocessing import clean_text
from best_ctx import BestCtxModel
from best_util import BestUtilModel
# Free-form experiment label; not referenced in the visible part of this module.
EXPERIMENT_NAME = '-'
import config_clf
# Configuration object handed to BestUtilModel when a Predictor is built.
util_config = config_clf.Config()
import config_ctx
# Configuration object handed to BestCtxModel when a Predictor is built.
ctx_config = config_ctx.Config()
# Special token ids. They are not referenced in the visible part of this
# module; presumably they mirror the tokenizer vocabulary used by the
# models -- TODO confirm against the training code.
bos_id = 1
eos_id = 2
pad_id = 0
class Predictor:
    """Rank command predictions by blending two model scores.

    A utility model first proposes ``(util, score)`` pairs for each input
    text; a context model then proposes ``(command, score)`` pairs for each
    proposed utility. Both scores are mixed linearly and the best commands
    are returned.
    """

    def __init__(self, path):
        """Build both model wrappers from the directory ``path``.

        Args:
            path: Directory passed to both model wrappers. The checkpoint
                files ``util_model.pth`` and ``ctx_model.pth`` are addressed
                directly inside it.
        """
        # NOTE(review): the trailing 'cpu' argument is presumably the torch
        # device the wrappers load onto -- confirm against the model classes.
        self.util_model = BestUtilModel(util_config, path, f'{path}/util_model.pth', 'cpu')
        self.ctx_model = BestCtxModel(ctx_config, path, f'{path}/ctx_model.pth', 'cpu')

    def predict_many(self, texts, result_cnt, *, alpha=0.6, n_utils=5, beam_width=5):
        """Return the top-scoring commands for every text in ``texts``.

        Args:
            texts: Iterable of raw input strings; each one is passed through
                ``clean_text`` before being given to the models.
            result_cnt: Upper bound on the number of commands kept per text
                (applied as a slice on the ranked candidates).
            alpha: Weight of the context-model score in the combined score;
                the utility-model score gets ``1 - alpha``.
            n_utils: Number of utility predictions requested per text, and
                the maximum number of them that are expanded.
            beam_width: Forwarded unchanged as the third argument of
                ``ctx_model.predict`` (presumably the beam size -- confirm
                against the context model).

        Returns:
            A list with one entry per input text. Each entry is a list of at
            most ``result_cnt`` commands, ordered by decreasing combined
            score.
        """
        cleaned_texts = [clean_text(text) for text in texts]
        pred_utils = self.util_model.predict_many(cleaned_texts, n_utils)

        result = []
        with torch.no_grad():
            for idx, text in enumerate(cleaned_texts):
                candidates = []
                # Walk the predictions that actually came back (capped at
                # n_utils) so a shorter list from the utility model does not
                # raise IndexError.
                for util, util_proba in pred_utils[idx][:n_utils]:
                    for pred_cmd, ctx_proba in self.ctx_model.predict(text, util, beam_width):
                        joined_proba = (1 - alpha) * util_proba + alpha * ctx_proba
                        candidates.append((pred_cmd, joined_proba))
                # sorted() is stable even with reverse=True, so candidates
                # with equal scores keep their generation order.
                ranked = sorted(candidates, key=lambda cand: cand[1], reverse=True)
                result.append([cmd for cmd, _ in ranked[:result_cnt]])
        return result
# Module-level Predictor instance, created at import time from the
# hard-coded submission directory; both model wrappers are constructed here.
predictor = Predictor('/nlc2cmd/src/submission_code')
# Alternative that builds the models from the current working directory:
# predictor = Predictor('./')