in paq/generation/filtering/filterer.py [0:0]
def __init__(
        self,
        corpus_path,
        index_path,
        index_id_to_db_id_path,
        model_path,
        batch_size,
        n_queries_to_parallelize,
        max_seq_len,
        n_docs,
        device
):
    """Record the settings and load the tokenizer and question encoder.

    Args:
        corpus_path: Stored on the instance as-is; not read here.
        index_path: Stored on the instance as-is; not read here.
        index_id_to_db_id_path: Stored on the instance as-is; not read here.
        model_path: Passed to ``from_pretrained`` for the config, the
            tokenizer and the underlying model.
        batch_size: Stored on the instance as-is.
        n_queries_to_parallelize: Stored on the instance as-is.
        max_seq_len: Stored on the instance as-is.
        n_docs: Stored on the instance as-is.
        device: ``None`` selects the CPU; any other value is formatted
            into ``"cuda:{device}"`` (presumably a GPU ordinal -- confirm
            against callers).
    """
    # Plain settings: kept verbatim for use by other methods.
    self.corpus_path = corpus_path
    self.index_path = index_path
    self.index_id_to_db_id_path = index_id_to_db_id_path
    self.n_docs = n_docs

    # No device given means CPU; otherwise address the requested CUDA device.
    if device is None:
        self.device = torch.device("cpu")
    else:
        self.device = torch.device(f"cuda:{device}")

    # One config object is shared by the tokenizer and the model load.
    model_config = AutoConfig.from_pretrained(model_path)
    self.tokenizer = AutoTokenizer.from_pretrained(model_path, config=model_config)

    # Wrap the pretrained backbone in the question encoder, move it to the
    # chosen device and call eval() on it.
    backbone = AutoModel.from_pretrained(model_path, config=model_config)
    self.model = DPRQuestionEncoder(backbone)
    self.model.to(self.device)
    self.model.eval()

    # Batching / tokenisation limits.
    self.batch_size = batch_size
    self.n_queries_to_parallelize = n_queries_to_parallelize
    self.max_seq_len = max_seq_len