def process_embeddings()

in parsers/Spouse/Spouse_Finetune_Dataset_Builder.py

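This method relies on the following module-level imports and on self.processed_dataset_folder being set by the builder's constructor (a minimal sketch; the actual module may organize its imports differently):

    import os
    from pathlib import Path

    import torch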

    def process_embeddings(self):
        for dataset_type in ['train', 'validation', 'test']:

            dataset_path = Path(self.processed_dataset_folder, f'Spouse_{dataset_type}')
            processed_path = Path(dataset_path, 'processed')

            example_no = 0
            for filename in [f for f in sorted(os.listdir(dataset_path)) if f.endswith('.torch')]:
                example_no += 1
                raw_example = torch.load(Path(dataset_path, filename))
                target, highlighted, sentences = raw_example['target'], raw_example['highlighted'], raw_example['sentences']

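                # Document-level accumulator: tensor fields start as None and grow via torch.cat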
                processed_example = {'target': torch.tensor([target], dtype=torch.long), 'tokens': [],
                                     'input_ids': None,
                                     'input_mask': None,
                                     'tokens_annotations': None}


                # Running offset giving a single index space over all tokens in the DOCUMENT
                starting_token_idx = 0

                for sentence_idx, sentence in enumerate(sentences):  # sentence_idx was used by the reference-embedding pipeline
                    # Identifiers carried in the raw example (not consumed below)
                    unique_sentence_id = sentence['unique_sentence_id']
                    sentence_example_id = sentence['example_id']

                    # The baseline will take the mean of the embeddings at runtime!
                    tokens_annotations = torch.tensor(sentence['tokens_annotations'], dtype=torch.long)  # CLS and SEP already removed

                    input_ids = sentence['input_ids']
                    input_mask = sentence['input_mask']
                    sentence_tokens = sentence['tokens']  # CLS and SEP already removed

                    # Number of tokens in this sentence; used below to advance the document-level offset
                    no_tokens = len(sentence_tokens)

                    # Concatenate this sentence's tensors onto the document-level accumulators
                    for key, val in [('input_ids', input_ids),
                                     ('tokens_annotations', tokens_annotations),
                                     ('input_mask', input_mask)]:

                        if processed_example[key] is None:
                            processed_example[key] = val
                        else:
                            processed_example[key] = torch.cat((processed_example[key], val), dim=0)

                    starting_token_idx += no_tokens

                    processed_example['tokens'].extend(sentence_tokens)

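                # Route highlighted examples to a dedicated 'highlighted' subfolder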
                if highlighted == 1:
                    store_path = Path(processed_path, 'highlighted')
                else:
                    store_path = processed_path

                os.makedirs(store_path, exist_ok=True)

                torch.save(processed_example, Path(store_path, f'{filename[:-6]}_processed.torch'))  # [:-6] strips '.torch'

                if example_no % 1000 == 0:
                    print(f'Processed {example_no} examples')
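Downstream code can reload a processed file directly. The sketch below is hypothetical: the folder layout and key names follow the code above, but the file name and the consuming code are assumptions, not part of this module:

    from pathlib import Path

    import torch

    # Hypothetical file name; actual names derive from the raw '.torch' files
    example = torch.load(Path('Spouse_train', 'processed', 'example_0_processed.torch'))
    print(example['target'].shape)      # torch.Size([1]), dtype long
    print(len(example['tokens']))       # total tokens across all sentences
    print(example['input_ids'].shape)   # sentence tensors concatenated along dim 0 (if stored as tensors)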