def compute_spouse_embeddings()

in parsers/Spouse/Spouse_Dataset_Builder.py [0:0]
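
This method pre-computes BERT embeddings for the Spouse dataset. For each of the train, validation and test splits it loads the formatted pickle, runs every non-empty sentence through BERT, averages the encoder layers, mean-pools the token embeddings into a per-sentence embedding, and caches the result as one `example_<id>.torch` file per example. Examples whose file already exists are skipped, so an interrupted run can be resumed. A hedged usage sketch follows the listing.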


    def compute_spouse_embeddings(self, bert_model, device):

        for dataset_type in ['train', 'validation', 'test']:

            with open(Path(self.processed_dataset_folder, f'formatted_{dataset_type}_dataset.pickle'), 'rb') as f:
                dataset = pickle.load(f)
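            # each pickle holds a list of (target, highlight, example) triples, where
            # `example` is itself a list of tokenized sentence features (see the loop below)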
            no_examples = len(dataset)

            dataset_path = Path(self.processed_dataset_folder, f'Spouse_{dataset_type}')
            os.makedirs(dataset_path, exist_ok=True)
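            # one torch file is written per example; existing files are skipped below,
            # which makes this method safely resumable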

            # Put BERT in eval mode (disables dropout) before extracting embeddings
            bert_model.eval()

            example_id = 0
            for target, highlight, example in dataset:

                example_filename = Path(dataset_path, f'example_{example_id}.torch')
                example_sentences = {'target': target, 'sentences': [], 'highlighted': highlight}

                if not os.path.exists(example_filename):
                    for sentence in example:
                        # Extract the fields of this feature, i.e. the sentence
                        unique_example_id = sentence.unique_example_id[0]  # tuple of single element
                        unique_sentence_id = sentence.unique_sentence_id[0]  # tuple of single element
                        tokens = sentence.tokens
                        annotations = sentence.annotations

                        if len(tokens) == 2:  # only [CLS] and [SEP] left: empty sentence
                            continue

                        input_ids = torch.tensor(sentence.input_ids, dtype=torch.long).unsqueeze(dim=0)
                        input_mask = torch.tensor(sentence.input_mask, dtype=torch.long).unsqueeze(dim=0)
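                        # unsqueeze(dim=0) turns each tensor into a single-sentence batch of shape (1, max_seq_len)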
                        # input_type_ids = sentence.input_type_ids  # This is unnecessary

                        # Map inputs to device for BERT
                        input_ids = input_ids.to(device)
                        input_mask = input_mask.to(device)

                        # Apply BERT
                        all_encoder_layers, _ = bert_model(input_ids, attention_mask=input_mask, token_type_ids=None)
                        all_encoder_layers = torch.stack(all_encoder_layers)
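                        # the stack has shape (no_layers, batch_size, MAX_DIM_EMB, EMBEDDING_DIM)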

                        no_layers = all_encoder_layers.shape[0]  # each layer has shape (batch_size, MAX_DIM_EMB, EMBEDDING_DIM)
                        avg_encoder_embeddings = ((torch.sum(all_encoder_layers, dim=0) / no_layers)
                                                  .squeeze().detach().cpu().numpy())

                        # drop the [CLS] and [SEP] rows, keeping only the real tokens' embeddings
                        avg_encoder_embeddings = avg_encoder_embeddings[1:len(tokens)-1, :]

                        assert example_id == unique_example_id

                        # mean-pool the remaining token embeddings into one sentence embedding
                        tokens_mean = np.sum(avg_encoder_embeddings, axis=0) / avg_encoder_embeddings.shape[0]
                        # Combine this sentence's fields in a dictionary; the whole example is saved below
                        sentence_dict = {'unique_sentence_id': sentence.unique_sentence_id,
                                         'example_id': example_id,
                                         'tokens_annotations': annotations[1:-1],  # without [CLS] and [SEP]
                                         'tokens_embeddings': avg_encoder_embeddings,
                                         'sentence_embeddings': np.expand_dims(tokens_mean, axis=0),
                                         'tokens': tokens[1:-1]}  # without [CLS] and [SEP]


                        assert len(sentence_dict['tokens_annotations']) == sentence_dict['tokens_embeddings'].shape[0], \
                            (len(sentence_dict['tokens_annotations']), sentence_dict['tokens_embeddings'].shape[0])

                        example_sentences['sentences'].append(sentence_dict)

                    # Store the example dict (with all its sentence dicts) in a single torch file
                    torch.save(example_sentences, example_filename)

                print(f'Completed example {example_id + 1}/{no_examples}')
                example_id += 1
            print()  # blank line between the train, validation and test passes
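
For reference, below is a minimal usage sketch. The call `bert_model(input_ids, attention_mask=..., token_type_ids=None)` returning `(all_encoder_layers, _)` matches the older `pytorch_pretrained_bert`-style `BertModel`, so that is what the sketch assumes; the `SpouseDatasetBuilder` class name and its constructor arguments are assumptions, since this excerpt only shows the one method.

    # minimal usage sketch -- the builder's class name and constructor are
    # assumptions; only compute_spouse_embeddings is shown in this excerpt
    import torch
    from pytorch_pretrained_bert import BertModel  # assumed: matches the (layers, pooled) return above

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    bert_model = BertModel.from_pretrained('bert-base-uncased').to(device)

    builder = SpouseDatasetBuilder(...)  # hypothetical constructor arguments
    builder.compute_spouse_embeddings(bert_model, device)

Each cached file can then be read back with `torch.load`; per the dictionaries built above, it holds the example's `target` and `highlighted` fields plus a list of per-sentence dicts:

    example = torch.load('<processed_dataset_folder>/Spouse_train/example_0.torch')
    first = example['sentences'][0]
    first['tokens_embeddings'].shape    # (num_tokens, hidden_dim) numpy array
    first['sentence_embeddings'].shape  # (1, hidden_dim) mean-pooled sentence embedding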