def collate()

in table_bert/vertical/dataset.py


from typing import Dict, List

import numpy as np
import torch

from table_bert.vertical.config import VerticalAttentionTableBertConfig


def collate(examples: List[Dict], config: VerticalAttentionTableBertConfig, train=True) -> Dict[str, torch.Tensor]:
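    """Pack a list of preprocessed table examples into padded batch tensors.

    Each example carries a `rows` list (one BERT input sequence per sampled
    table row) plus MLM supervision fields when `train=True`; all arrays are
    padded to the longest sequence / row count / column count in the batch.
    """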
    batch_size = len(examples)
    max_sequence_len = max(
        len(row['token_ids'])
        for e in examples
        for row in e['rows']
    )
    max_context_len = max(
        row['context_span'][1] - row['context_span'][0]
        for e in examples
        for row in e['rows']
    )
    max_row_num = max(len(inst['rows']) for inst in examples)
    # max_row_num = max(inst['table_size'][0] for inst in examples)
    # max_column_num = max(inst['table_size'][1] for inst in examples)

    input_ids = np.zeros((batch_size, max_row_num, max_sequence_len), dtype=np.int64)
    mask_array = np.zeros((batch_size, max_row_num, max_sequence_len), dtype=np.float32)
    segment_ids = np.zeros((batch_size, max_row_num, max_sequence_len), dtype=np.int64)

    # table specific tensors
    context_token_positions = np.zeros((batch_size, max_row_num, max_context_len), dtype=np.int64)
    context_token_mask = np.zeros((batch_size, max_row_num, max_context_len), dtype=bool)

    row_column_nums = []  # per-row column counts, flattened across the whole batch

    # initialize the mapping with a large sentinel value; once the real column
    # count is known, sentinel entries are remapped to one extra trailing column
    # that serves as the "garbage collection" entry for reduce ops (see below)
    column_token_to_column_id_fill_val = np.iinfo(np.uint16).max
    column_token_position_to_column_ids = np.full(
        (batch_size, max_row_num, max_sequence_len),
        dtype=np.int64,
        fill_value=column_token_to_column_id_fill_val
    )
    # column_token_mask = np.zeros((batch_size, max_row_num, max_sequence_len), dtype=np.bool)

    if train:
        # MLM objectives; fill value -1 marks positions that carry no label
        masked_context_token_label_ids = np.full((batch_size, max_context_len), dtype=np.int64, fill_value=-1)

        max_column_pred_token_num = max(len(e['masked_column_token_column_ids']) for e in examples)
        masked_column_token_column_ids = np.zeros((batch_size, max_column_pred_token_num), dtype=np.int64)
        masked_column_token_label_ids = np.full((batch_size, max_column_pred_token_num), dtype=np.int64, fill_value=-1)

        # cell token prediction
        predict_cell_tokens = config.predict_cell_tokens
        if predict_cell_tokens:
            max_masked_cell_token_num = max(
                len(row['masked_cell_token_positions'])
                for e in examples
                for row in e['rows']
            )

            masked_cell_token_positions = np.zeros((batch_size, max_row_num, max_masked_cell_token_num), dtype=np.int64)
            masked_cell_token_column_ids = np.zeros((batch_size, max_row_num, max_masked_cell_token_num), dtype=np.int64)
            masked_cell_token_label_ids = np.full((batch_size, max_row_num, max_masked_cell_token_num), dtype=np.int64,
                                                  fill_value=-1)

    for e_id, example in enumerate(examples):
        for row_id, row_inst in enumerate(example['rows']):
            bert_input_seq_length = len(row_inst['token_ids'])

            input_ids[e_id, row_id, :bert_input_seq_length] = row_inst['token_ids']
            mask_array[e_id, row_id, :bert_input_seq_length] = 1
            # segment B (the linearized table row) starts right after segment A;
            # padded positions are neutralized by the sequence mask downstream
            segment_ids[e_id, row_id, row_inst['segment_a_length']:] = 1

            row_context_token_positions = list(range(
                row_inst['context_span'][0],
                row_inst['context_span'][1]
            ))
            context_token_positions[e_id, row_id, :len(row_context_token_positions)] = row_context_token_positions
            # mask the first len(span) slots, mirroring the positions filled above
            context_token_mask[e_id, row_id, :len(row_context_token_positions)] = 1

            row_column_token_position_to_column_ids = row_inst['column_token_position_to_column_ids']
            if not train:
                # inference examples store this as a plain list; the vectorized
                # ops below need an ndarray
                row_column_token_position_to_column_ids = np.array(row_column_token_position_to_column_ids)

            cur_column_num = row_column_token_position_to_column_ids[
                row_column_token_position_to_column_ids != column_token_to_column_id_fill_val
            ].max() + 1
            row_column_nums.append(cur_column_num)
            column_token_position_to_column_ids[e_id, row_id, :len(row_column_token_position_to_column_ids)] = row_column_token_position_to_column_ids

            if train and predict_cell_tokens:
                row_masked_cell_token_positions = row_inst['masked_cell_token_positions']
                masked_cell_token_positions[e_id, row_id, :len(row_masked_cell_token_positions)] = row_masked_cell_token_positions
                masked_cell_token_column_ids[e_id, row_id, :len(row_masked_cell_token_positions)] = [
                    row_column_token_position_to_column_ids[pos] for pos in row_masked_cell_token_positions]
                masked_cell_token_label_ids[e_id, row_id, :len(row_masked_cell_token_positions)] = row_inst['masked_cell_token_label_ids']

        # row_num, column_num = example['table_size']
        # table_mask[e_id, :row_num, :column_num] = 1.

        if train:
            masked_context_token_label_ids[e_id, example['masked_context_token_positions']] = example['masked_context_token_label_ids']

            masked_column_token_num = len(example['masked_column_token_column_ids'])
            masked_column_token_column_ids[e_id, :masked_column_token_num] = example['masked_column_token_column_ids']
            masked_column_token_label_ids[e_id, :masked_column_token_num] = example['masked_column_token_label_ids']

    max_column_num = max(row_column_nums)
    table_mask = np.zeros((batch_size, max_row_num, max_column_num), dtype=bool)
    global_row_id = 0
    for e_id, example in enumerate(examples):
        for row_id, row_inst in enumerate(example['rows']):
            row_column_num = row_column_nums[global_row_id]
            table_mask[e_id, row_id, :row_column_num] = 1
            global_row_id += 1

    column_token_position_to_column_ids[
        column_token_position_to_column_ids == column_token_to_column_id_fill_val] = max_column_num
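    # Downstream, column encodings are typically built with a scatter/segment
    # reduce over these ids; the extra trailing column (id == max_column_num)
    # absorbs context/special/padding tokens and is dropped after the reduction
    # (see the pooling sketch at the end of this listing).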

    # for table_id in range(len(examples)):
    #     row_num, column_num = examples[table_id]['table_size']
    #     for row_id in range(row_num):
    #         for column_id in range(column_num):
    #             assert column_id in column_token_position_to_column_ids[table_id, row_id]
    #     for masked_col_id in masked_column_token_column_ids[table_id]:
    #         assert masked_col_id < column_num
    #
    # assert column_token_position_to_column_ids.flatten().max() == table_mask.sum(axis=-1).max()

    tensor_dict = {
        'input_ids': torch.tensor(input_ids, dtype=torch.long),
        'segment_ids': torch.tensor(segment_ids, dtype=torch.long),
        'context_token_positions': torch.tensor(context_token_positions, dtype=torch.long),
        'column_token_position_to_column_ids': torch.tensor(column_token_position_to_column_ids, dtype=torch.long),
        'sequence_mask': torch.tensor(mask_array, dtype=torch.float32),
        'context_token_mask': torch.tensor(context_token_mask, dtype=torch.float32),
        'table_mask': torch.tensor(table_mask, dtype=torch.float32),
    }

    if train:
        sample_size = int((masked_context_token_label_ids != -1).sum() + (masked_column_token_label_ids != -1).sum())
        if predict_cell_tokens:
            sample_size += int((masked_cell_token_label_ids != -1).sum())

        tensor_dict.update({
            'masked_context_token_labels': torch.tensor(masked_context_token_label_ids, dtype=torch.long),
            'masked_column_token_column_ids': torch.tensor(masked_column_token_column_ids, dtype=torch.long),
            'masked_column_token_labels': torch.tensor(masked_column_token_label_ids, dtype=torch.long),
            'sample_size': sample_size
        })

        if predict_cell_tokens:
            tensor_dict.update({
                'masked_cell_token_positions': torch.tensor(masked_cell_token_positions, dtype=torch.long),
                'masked_cell_token_column_ids': torch.tensor(masked_cell_token_column_ids, dtype=torch.long),
                'masked_cell_token_labels': torch.tensor(masked_cell_token_label_ids, dtype=torch.long),
            })

    return tensor_dict
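
A minimal smoke test, assuming the example schema implied by the reads above (the field names are taken directly from collate(); every value is hypothetical). It builds one single-row, two-column example, substitutes a SimpleNamespace for the real VerticalAttentionTableBertConfig, and prints the resulting tensor shapes:

from types import SimpleNamespace

import numpy as np

FILL = np.iinfo(np.uint16).max  # same sentinel collate() uses internally

example = {
    'rows': [{
        # [CLS] ctx ctx [SEP] col0 col0 col1 [SEP]  (hypothetical token ids)
        'token_ids': [101, 7592, 2088, 102, 2171, 2003, 2287, 102],
        'context_span': (1, 3),   # tokens 1..2 are the context
        'segment_a_length': 4,    # [CLS] + context + [SEP]
        'column_token_position_to_column_ids':
            np.array([FILL, FILL, FILL, FILL, 0, 0, 1, FILL]),
    }],
    # positions are relative to the context span
    'masked_context_token_positions': [0],
    'masked_context_token_label_ids': [7592],
    'masked_column_token_column_ids': [0],
    'masked_column_token_label_ids': [2171],
}

config = SimpleNamespace(predict_cell_tokens=False)  # stand-in for the real config
batch = collate([example], config, train=True)
for name, value in batch.items():
    print(name, getattr(value, 'shape', value))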
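
And a sketch of the reduce op the trailing "garbage collection" column exists for: with one extra column slot, a single scatter_add_ can pool every token unconditionally, and the garbage slot is simply sliced off afterwards. This is my own illustration of the pattern, not code from the repo; the real vertical-attention model may pool differently:

import torch

def mean_pool_columns(token_encodings, column_ids, table_mask):
    """token_encodings: (batch, rows, seq_len, hidden) float
    column_ids:         (batch, rows, seq_len) long, values in [0, max_column_num]
    table_mask:         (batch, rows, max_column_num) float
    """
    batch, rows, seq_len, hidden = token_encodings.size()
    max_column_num = table_mask.size(-1)

    # scatter every token into its column slot; slot `max_column_num` is the
    # garbage column that absorbs context/special/padding tokens
    index = column_ids.unsqueeze(-1).expand(-1, -1, -1, hidden)
    sums = token_encodings.new_zeros(batch, rows, max_column_num + 1, hidden)
    sums.scatter_add_(2, index, token_encodings)

    counts = token_encodings.new_zeros(batch, rows, max_column_num + 1)
    counts.scatter_add_(2, column_ids, torch.ones_like(column_ids, dtype=token_encodings.dtype))

    means = sums / counts.clamp(min=1).unsqueeze(-1)
    # drop the garbage column and zero out padded columns
    return means[:, :, :max_column_num] * table_mask.unsqueeze(-1)

# e.g. with the batch from the smoke test above:
# enc = torch.randn(1, 1, 8, 768)
# cols = mean_pool_columns(enc, batch['column_token_position_to_column_ids'], batch['table_mask'])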