in paq/generation/filtering/filterer.py [0:0]
def __init__(self,
             model_path: str,
             batch_size: int = 10,
             device: int = 0,
             max_seq_len: int = 200,
             n_docs: int = 50,
             ):
    """Load a FiDT5 filtering model and its T5 tokenizer onto the requested device.

    Args:
        model_path: path/identifier of the pretrained FiDT5 checkpoint.
        batch_size: number of examples processed per batch.
        device: CUDA device ordinal; ``None`` falls back to CPU.
        max_seq_len: maximum sequence length passed to the collator.
        n_docs: presumably the number of documents consumed per query —
            TODO(review): confirm against the caller.
    """
    # Plain configuration first; these assignments have no side effects.
    self.batch_size = batch_size
    self.max_seq_len = max_seq_len
    self.n_docs = n_docs

    # None selects CPU; any integer selects that CUDA ordinal.
    self.device = torch.device("cpu" if device is None else f"cuda:{device}")

    self.tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base', return_dict=False)

    # Order matters below: load, move to device, switch to eval mode,
    # then re-wrap the encoder.
    self.model = src.model.FiDT5.from_pretrained(model_path)
    self.model.to(self.device)
    self.model.eval()
    # HACK: wrap the inner encoder so FiD stays compatible with newer
    # transformers versions.
    self.model.encoder = CompatableEncoderWrapper(self.model.encoder.encoder)

    self.collator = src.data.Collator(self.max_seq_len, self.tokenizer)