def process_embeddings()

in parsers/MovieReview/MovieReview_Dataset_Builder.py [0:0]
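
This method aggregates the per-sentence token and sentence embeddings stored in
each MovieReview_train / MovieReview_test example file into a single
per-document example and saves the result under a 'processed' subfolder. It
assumes the following module-level imports (standard library plus NumPy and
PyTorch, all of which the method body relies on):

    import os
    from pathlib import Path

    import numpy as np
    import torch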


    def process_embeddings(self, embedding_dim):
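        # NOTE: the embedding_dim argument is currently unused within this method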
        for dataset_type in ['train', 'test']:

            dataset_path = Path(self.processed_dataset_folder, f'MovieReview_{dataset_type}')
            processed_path = Path(dataset_path, 'processed')

            example_filenames = [f for f in sorted(os.listdir(dataset_path)) if f.endswith('.torch')]
            for example_no, example_filename in enumerate(example_filenames, start=1):
                example = torch.load(Path(dataset_path, example_filename))
                target, example = example['target'], example['sentences']

                processed_example = {'target': torch.tensor([target], dtype=torch.long), 'tokens': [],
                                     'tokens_embeddings': None, 'sentence_embeddings': None,
                                     'tokens_annotations': None}


                # Running offset: maintains a single contiguous indexing for all tokens in the document
                starting_token_idx = 0

                for sentence_idx, sentence in enumerate(example):  # sentence_idx is used by reference embeddings
                    unique_sentence_id = sentence['unique_sentence_id']
                    sentence_example_id = sentence['example_id']

                    # The baseline will take the mean of the embeddings at runtime!
                    tokens_annotations = torch.from_numpy(np.array(sentence['tokens_annotations'])).long()  # CLS and SEP already removed
                    tokens_embeddings = torch.from_numpy(sentence['tokens_embeddings'])  # CLS and SEP already removed
                    sentence_embeddings = torch.from_numpy(sentence['sentence_embeddings'])  # CLS and SEP already removed
                    sentence_tokens = sentence['tokens']  # CLS and SEP already removed
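
                    # Expected shapes (an assumption about how the upstream parser stores 2-D arrays):
                    #   tokens_embeddings:   (n_sentence_tokens, embedding_dim)
                    #   sentence_embeddings: (1, embedding_dim)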

                    # print(tokens_annotations.shape, tokens_embeddings.shape, sentence_embeddings.shape, len(sentence_tokens))

                    # Number of tokens in this sentence; used to advance the document-level offset
                    no_tokens = len(sentence_tokens)

                    # Now update example info by concatenating everything
                    for key, val in [('tokens_embeddings', tokens_embeddings),
                                     ('tokens_annotations', tokens_annotations),
                                     ('sentence_embeddings', sentence_embeddings)]:

                        if processed_example[key] is None:
                            processed_example[key] = val
                        else:
                            processed_example[key] = torch.cat((processed_example[key], val), dim=0)

                    starting_token_idx += no_tokens

                    processed_example['tokens'].extend(sentence_tokens)

                # Save the aggregated document example under the processed/ folder
                os.makedirs(processed_path, exist_ok=True)

                # example_filename[:-6] strips the '.torch' suffix from the source filename
                torch.save(processed_example, Path(processed_path, f'{example_filename[:-6]}_processed.torch'))

                if example_no % 1000 == 0:
                    print(f'Processed {example_no} examples')
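
A quick sanity check of the resulting files (a minimal sketch; the
processed_dataset_folder value below is hypothetical and should be set to the
builder's configured output root):

    import torch
    from pathlib import Path

    processed_dataset_folder = 'data/MovieReview'  # assumption: the builder's output root
    processed_path = Path(processed_dataset_folder, 'MovieReview_train', 'processed')
    example = torch.load(next(processed_path.glob('*_processed.torch')))

    # One processed file holds one whole document
    print(example['target'].shape)              # torch.Size([1]), dtype torch.long
    print(len(example['tokens']))               # total number of tokens across all sentences
    print(example['tokens_embeddings'].shape)   # (n_tokens, embedding_dim), assuming 2-D per-sentence tensors
    print(example['sentence_embeddings'].shape) # (n_sentences, embedding_dim), under the same assumption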