def __init__()

in paq/generation/filtering/filterer.py [0:0]


    def __init__(self,
                 model_path: str,
                 batch_size: int = 10,
                 device: int = 0,
                 max_seq_len: int = 200,
                 n_docs:int = 50,
                 ):
        """Load the FiD filtering model and tokenizer for inference.

        Args:
            model_path: path to a pretrained FiDT5 checkpoint.
            batch_size: number of examples scored per forward pass.
            device: CUDA device index; pass ``None`` to run on CPU.
            max_seq_len: maximum token length used by the collator.
            n_docs: number of retrieved documents kept per query.
        """
        # Select the CUDA device when an index is supplied, otherwise CPU.
        if device is not None:
            self.device = torch.device(f"cuda:{device}")
        else:
            self.device = torch.device("cpu")

        self.tokenizer = transformers.T5Tokenizer.from_pretrained('t5-base', return_dict=False)

        # Load the reader model, move it to the chosen device, and freeze it
        # in eval mode — this object only ever runs inference.
        self.model = src.model.FiDT5.from_pretrained(model_path)
        self.model.to(self.device)
        self.model.eval()
        # HACK: re-wrap the encoder so FiD stays compatible with newer
        # transformers versions.
        self.model.encoder = CompatableEncoderWrapper(self.model.encoder.encoder)

        # Batching / collation configuration.
        self.batch_size = batch_size
        self.max_seq_len = max_seq_len
        self.n_docs = n_docs
        self.collator = src.data.Collator(self.max_seq_len, self.tokenizer)