def create_pretraining_instance()

in table_bert/vertical/input_formatter.py [0:0]


    def create_pretraining_instance(self, context: List[str], table: Table, example: Example):
        assert self.config.table_mask_strategy == 'column'

        row_instances = self.get_input(context, table)['rows']

        num_maskable_columns = min(len(row_inst['column_spans']) for row_inst in row_instances)
        num_column_to_mask = max(1, math.ceil(num_maskable_columns * self.config.masked_column_prob))
        columns_to_mask = sorted(random.sample(list(range(num_maskable_columns)), num_column_to_mask))

        masked_column_token_indices_list = []
        masked_column_token_column_ids = []

        masked_cell_token_indices_list = []
        masked_cell_token_column_ids_list = []
        masked_cell_token_labels_list = []
        for row_id, row_instance in enumerate(row_instances):
            maskable_column_token_indices = [
                (
                        list(range(*span['column_name'])) +
                        list(range(*span['type']))
                )
                for col_id, span
                in enumerate(row_instance['column_spans'])
            ]

            masked_column_token_indices = [
                token_idx
                for col_id in columns_to_mask
                for token_idx in maskable_column_token_indices[col_id]
            ]

            masked_cell_token_indices = [
                range(*row_instance['column_spans'][col_id]['value'])
                for col_id in columns_to_mask
            ]

            # masked_cell_token_column_ids = [
            #     col_id
            #     for col_id in columns_to_mask
            #     for token_idx in masked_cell_token_indices[col_id]
            # ]

            masked_cell_token_indices = list(chain(*masked_cell_token_indices))
            masked_cell_token_labels = [row_instance['tokens'][pos] for pos in masked_cell_token_indices]

            masked_cell_token_indices_list.append(masked_cell_token_indices)
            # masked_cell_token_column_ids_list.append(masked_cell_token_column_ids)
            masked_cell_token_labels_list.append(masked_cell_token_labels)

            if row_id == 0:
                masked_column_token_column_ids = [
                    col_id
                    for col_id in columns_to_mask
                    for token_idx in maskable_column_token_indices[col_id]
                ]

            masked_column_token_indices_list.append(masked_column_token_indices)

        num_masked_column_tokens = len(masked_column_token_indices_list[0])
        assert all(len(mask_list) == num_masked_column_tokens for mask_list in masked_column_token_indices_list)

        max_context_token_to_mask = self.config.max_predictions_per_seq - num_masked_column_tokens

        context_token_indices = (
            list(range(*row_instances[0]['context_span']))[1:]
            if self.config.context_first else
            list(range(*row_instances[0]['context_span']))[:-1]
        )

        num_context_tokens_to_mask = min(
            max_context_token_to_mask,
            max(
                1,
                int(len(context_token_indices) * self.config.masked_context_prob)
            )
        )
        if num_context_tokens_to_mask > 0:
            masked_context_token_indices = sorted(random.sample(context_token_indices, num_context_tokens_to_mask))
        else:
            masked_context_token_indices = []

        masked_token_indices_list = []
        for row_id, row_instance in enumerate(row_instances):
            masked_token_indices_list.append(
                masked_context_token_indices + masked_column_token_indices_list[row_id]
            )

        first_row_tokens = row_instances[0]['tokens']
        masked_context_token_labels = [first_row_tokens[idx] for idx in masked_context_token_indices]
        masked_column_token_labels = [first_row_tokens[idx] for idx in masked_column_token_indices_list[0]]
        masked_token_labels = [first_row_tokens[idx] for idx in masked_token_indices_list[0]]

        for token_relative_idx, token in enumerate(masked_token_labels):
            # 80% of the time, replace with [MASK]
            if random.random() < 0.8:
                masked_token = "[MASK]"
            else:
                # 10% of the time, keep original
                if random.random() < 0.5:
                    masked_token = token
                # 10% of the time, replace with random word
                else:
                    masked_token = random.choice(self.vocab_list)

            # Once we've saved the true label for that token, we can overwrite it with the masked version
            for row_id, row_instance in enumerate(row_instances):
                token_idx = masked_token_indices_list[row_id][token_relative_idx]
                row_instance['tokens'][token_idx] = masked_token

        if (
            self.config.predict_cell_tokens and
            all(
                len(masked_cell_token_indices_list[i]) == 0
                for i in range(len(row_instances))
            )
        ):
            return None

        info = {}
        pretrain_instance = {
            "rows": [
                {
                    'tokens': row_instance['tokens'],
                    'token_ids': self.tokenizer.convert_tokens_to_ids(row_instance['tokens']),
                    'segment_a_length': row_instance['segment_a_length'],
                    'context_span': row_instance['context_span'],
                    'column_token_position_to_column_ids': row_instance['column_token_position_to_column_ids'],
                    'masked_cell_token_positions': masked_cell_token_indices_list[row_id] if self.config.predict_cell_tokens else None,
                    'masked_cell_token_label_ids': self.tokenizer.convert_tokens_to_ids(masked_cell_token_labels_list[row_id]) if self.config.predict_cell_tokens else None
                }
                for row_id, row_instance
                in enumerate(row_instances)
            ],
            'table_size': (len(row_instances), num_maskable_columns),
            'masked_context_token_positions': masked_context_token_indices,
            'masked_context_token_labels': masked_context_token_labels,
            'masked_context_token_label_ids': self.tokenizer.convert_tokens_to_ids(masked_context_token_labels),
            'masked_column_token_column_ids': masked_column_token_column_ids,
            'masked_column_token_labels': masked_column_token_labels,
            'masked_column_token_label_ids': self.tokenizer.convert_tokens_to_ids(masked_column_token_labels),
            "info": info
        }

        assert all(x < num_maskable_columns for x in masked_column_token_column_ids)

        return pretrain_instance