src/pixparse/task/task_cruller_eval_cord.py [283:324]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def text_input_to_target(self, text_input, ignore_id=-100):
        target = text_input.clone()
        # model doesn't need to predict pad token
        target[target == self.tokenizer.trunk.pad_token_id] = ignore_id
        # model doesn't need to predict prompt (for VQA)
        prompt_end_token_id = self.tokenizer.trunk.convert_tokens_to_ids(
            self.prompt_end_token
        )
        # mask out everything up to and including the prompt-end token;
        # torch.nonzero(...).sum() recovers that token's index only because
        # the prompt-end token is assumed to appear exactly once
        slice_id = torch.nonzero(target == prompt_end_token_id).sum() + 1
        target[:slice_id] = ignore_id
        return target

    def collate_fn(self, batch):
        """
        basic collator for PIL images, as returned by rvlcdip dataloader (among others)
        """

        # TODO move this to a __getitem__ for pickling
        def tokenizer_fn(x):
            # special tokens are embedded in the string itself, so none are
            # added here; pad/truncate to a fixed 512-token sequence
            return self.tokenizer.trunk(
                x,
                add_special_tokens=False,
                return_tensors="pt",
                max_length=512,
                padding="max_length",
                truncation=True,
            ).input_ids[0]

        images = [item["image"] for item in batch]
        raw_texts = [literal_eval(item["ground_truth"])["gt_parse"] for item in batch]
        inputs_to_stack = []
        for text in raw_texts:
            tokens_from_json, _ = json2token(text, self.tokenizer.trunk.all_special_tokens, sort_json_key=False)
            inputs_to_stack.append(tokenizer_fn(
                self.task_start_token
                #+ self.tokenizer.trunk.bos_token
                + tokens_from_json
                + self.tokenizer.trunk.eos_token
            ))
        text_inputs = torch.stack(inputs_to_stack)
        targets = torch.stack([self.text_input_to_target(text) for text in text_inputs])
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
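
A minimal standalone sketch (not from the repo) of the masking performed by
text_input_to_target above, assuming a toy vocabulary where 0 is the pad id
and 5 is the prompt-end id; it shows why the nonzero(...).sum() trick only
works when the prompt-end token appears exactly once in the sequence:

    import torch

    ignore_id = -100
    pad_token_id = 0
    prompt_end_token_id = 5

    text_input = torch.tensor([7, 5, 8, 9, 3, 0, 0])  # prompt ends at index 1
    target = text_input.clone()
    target[target == pad_token_id] = ignore_id            # mask padding
    slice_id = torch.nonzero(target == prompt_end_token_id).sum() + 1
    target[:slice_id] = ignore_id                         # mask prompt
    print(target)  # tensor([-100, -100, 8, 9, 3, -100, -100])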



src/pixparse/task/task_cruller_finetune_CORD.py [384:425]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    def text_input_to_target(self, text_input, ignore_id=-100):
        target = text_input.clone()
        # model doesn't need to predict pad token
        target[target == self.tokenizer.trunk.pad_token_id] = ignore_id
        # model doesn't need to predict prompt (for VQA)
        prompt_end_token_id = self.tokenizer.trunk.convert_tokens_to_ids(
            self.prompt_end_token
        )
        # mask out everything up to and including the prompt-end token;
        # torch.nonzero(...).sum() recovers that token's index only because
        # the prompt-end token is assumed to appear exactly once
        slice_id = torch.nonzero(target == prompt_end_token_id).sum() + 1
        target[:slice_id] = ignore_id
        return target

    def collate_fn(self, batch):
        """
        basic collator for PIL images, as returned by rvlcdip dataloader (among others)
        """
        # FIXME move this batcher/tokenizer elsewhere
        def tokenizer_fn(x):
            # special tokens are embedded in the string itself, so none are
            # added here; pad/truncate to a fixed 512-token sequence
            return self.tokenizer.trunk(
                x,
                add_special_tokens=False,
                return_tensors="pt",
                max_length=512,
                padding="max_length",
                truncation=True,
            ).input_ids[0]

        images = [item["image"] for item in batch]
        raw_texts = [literal_eval(item["ground_truth"])["gt_parse"] for item in batch]
        inputs_to_stack = []
        for text in raw_texts:
            tokens_from_json, _ = json2token(text, self.tokenizer.trunk.all_special_tokens, sort_json_key=False)
            inputs_to_stack.append(tokenizer_fn(
                self.task_start_token
                #+ self.tokenizer.trunk.bos_token
                + tokens_from_json
                + self.tokenizer.trunk.eos_token
            ))
        text_inputs = torch.stack(inputs_to_stack)
        targets = torch.stack([self.text_input_to_target(text) for text in text_inputs])
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
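
Since the two excerpts above are near-identical, a hedged dedup sketch
(hypothetical names, not present in the repo): both task classes could
inherit the shared helpers from one mixin, as each already provides
self.tokenizer, self.task_start_token, and self.prompt_end_token:

    class CordTextTargetMixin:
        """Shared CORD target masking + collation for eval/finetune tasks."""

        def text_input_to_target(self, text_input, ignore_id=-100):
            ...  # identical body to the excerpts above

        def collate_fn(self, batch):
            ...  # identical body to the excerpts above

    # class TaskCrullerEvalCORD(CordTextTargetMixin, ...): ...
    # class TaskCrullerFinetuneCORD(CordTextTargetMixin, ...): ...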



