in table_bert/vertical/dataset.py [0:0]

from typing import Dict, List

import numpy as np
import torch

# NOTE: the import path for the config class is assumed from this file's
# location in the package; the original module may import it differently.
from table_bert.vertical.config import VerticalAttentionTableBertConfig


def collate(examples: List[Dict], config: VerticalAttentionTableBertConfig, train=True) -> Dict[str, torch.Tensor]:
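    """Pad a batch of preprocessed table/context examples into model-ready tensors.

    Each example is a dict with a `rows` list (one entry per linearized table
    row, carrying `token_ids`, `context_span`, `segment_a_length`, and
    `column_token_position_to_column_ids`) plus, when `train=True`, the masked
    context/column (and optionally cell) token positions and label ids for the
    MLM objectives. Returns a dict of padded `torch.Tensor`s, with an integer
    `sample_size` entry added in training mode.
    """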
    batch_size = len(examples)
    max_sequence_len = max(
        len(row['token_ids'])
        for e in examples
        for row in e['rows']
    )
    max_context_len = max(
        row['context_span'][1] - row['context_span'][0]
        for e in examples
        for row in e['rows']
    )
    max_row_num = max(len(inst['rows']) for inst in examples)
    # max_row_num = max(inst['table_size'][0] for inst in examples)
    # max_column_num = max(inst['table_size'][1] for inst in examples)
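    # the three padding dimensions computed above (rows per table, tokens per
    # row, context tokens per row) fix the shapes of every batched buffer below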
    input_ids = np.zeros((batch_size, max_row_num, max_sequence_len), dtype=np.int64)
    mask_array = np.zeros((batch_size, max_row_num, max_sequence_len), dtype=np.float32)
    segment_ids = np.zeros((batch_size, max_row_num, max_sequence_len), dtype=np.int64)

    # table specific tensors
    context_token_positions = np.zeros((batch_size, max_row_num, max_context_len), dtype=np.int64)
    context_token_mask = np.zeros((batch_size, max_row_num, max_context_len), dtype=bool)

    row_column_nums = []
    # we initialize the mapping with the id of the last column as the "garbage collection" entry for reduce ops
    column_token_to_column_id_fill_val = np.iinfo(np.uint16).max
    column_token_position_to_column_ids = np.full(
        (batch_size, max_row_num, max_sequence_len),
        dtype=np.int64, fill_value=column_token_to_column_id_fill_val
    )
    # column_token_mask = np.zeros((batch_size, max_row_num, max_sequence_len), dtype=bool)
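    # Positions that belong to no column keep the uint16 sentinel for now; once
    # the per-row column counts are known (after the main loop below), the
    # sentinel is rewritten to max_column_num, i.e. one slot past the last real
    # column, so scatter/reduce ops can dump non-column tokens into a
    # throwaway entry.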
    if train:
        # MLM objectives
        masked_context_token_label_ids = np.full((batch_size, max_context_len), dtype=np.int64, fill_value=-1)

        max_column_pred_token_num = max(len(e['masked_column_token_column_ids']) for e in examples)
        masked_column_token_column_ids = np.zeros((batch_size, max_column_pred_token_num), dtype=np.int64)
        masked_column_token_label_ids = np.full((batch_size, max_column_pred_token_num), dtype=np.int64, fill_value=-1)

        # cell token prediction
        predict_cell_tokens = config.predict_cell_tokens
        if predict_cell_tokens:
            max_masked_cell_token_num = max(
                len(row['masked_cell_token_positions'])
                for e in examples
                for row in e['rows']
            )
            masked_cell_token_positions = np.zeros((batch_size, max_row_num, max_masked_cell_token_num), dtype=np.int64)
            masked_cell_token_column_ids = np.zeros((batch_size, max_row_num, max_masked_cell_token_num), dtype=np.int64)
            masked_cell_token_label_ids = np.full((batch_size, max_row_num, max_masked_cell_token_num), dtype=np.int64, fill_value=-1)
    for e_id, example in enumerate(examples):
        for row_id, row_inst in enumerate(example['rows']):
            bert_input_seq_length = len(row_inst['token_ids'])
            input_ids[e_id, row_id, :bert_input_seq_length] = row_inst['token_ids']
            mask_array[e_id, row_id, :bert_input_seq_length] = 1
            segment_ids[e_id, row_id, row_inst['segment_a_length']:] = 1

            row_context_token_positions = list(range(
                row_inst['context_span'][0],
                row_inst['context_span'][1]
            ))
            context_token_positions[e_id, row_id, :len(row_context_token_positions)] = row_context_token_positions
            # mark the slots filled above as valid context positions; the mask
            # axis has size max_context_len, so it is indexed by slot, not by
            # absolute token position
            context_token_mask[e_id, row_id, :len(row_context_token_positions)] = 1
            row_column_token_position_to_column_ids = row_inst['column_token_position_to_column_ids']
            if not train:
                # at inference time this field arrives as a plain list; the
                # boolean mask below requires a numpy array
                row_column_token_position_to_column_ids = np.array(row_column_token_position_to_column_ids)

            cur_column_num = row_column_token_position_to_column_ids[
                row_column_token_position_to_column_ids != column_token_to_column_id_fill_val].max() + 1
            row_column_nums.append(cur_column_num)

            column_token_position_to_column_ids[e_id, row_id, :len(row_column_token_position_to_column_ids)] = row_column_token_position_to_column_ids
            if train and predict_cell_tokens:
                row_masked_cell_token_positions = row_inst['masked_cell_token_positions']
                masked_cell_token_positions[e_id, row_id, :len(row_masked_cell_token_positions)] = row_masked_cell_token_positions
                masked_cell_token_column_ids[e_id, row_id, :len(row_masked_cell_token_positions)] = [
                    row_column_token_position_to_column_ids[pos] for pos in row_masked_cell_token_positions]
                masked_cell_token_label_ids[e_id, row_id, :len(row_masked_cell_token_positions)] = row_inst['masked_cell_token_label_ids']

        # row_num, column_num = example['table_size']
        # table_mask[e_id, :row_num, :column_num] = 1.
        if train:
            masked_context_token_label_ids[e_id, example['masked_context_token_positions']] = example['masked_context_token_label_ids']

            masked_column_token_num = len(example['masked_column_token_column_ids'])
            masked_column_token_column_ids[e_id, :masked_column_token_num] = example['masked_column_token_column_ids']
            masked_column_token_label_ids[e_id, :masked_column_token_num] = example['masked_column_token_label_ids']
    max_column_num = max(row_column_nums)
    table_mask = np.zeros((batch_size, max_row_num, max_column_num), dtype=bool)

    # row_column_nums holds one entry per (example, row) pair in visit order,
    # so a flat row counter (not a column id) indexes it
    flat_row_id = 0
    for e_id, example in enumerate(examples):
        for row_id, row_inst in enumerate(example['rows']):
            row_column_num = row_column_nums[flat_row_id]
            table_mask[e_id, row_id, :row_column_num] = 1
            flat_row_id += 1

    column_token_position_to_column_ids[
        column_token_position_to_column_ids == column_token_to_column_id_fill_val] = max_column_num
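    # e.g. with max_column_num == 2, a row mapping such as
    #   [65535, 65535, 0, 0, 1, 65535]    ([CLS]/context/[SEP] -> sentinel)
    # becomes
    #   [2, 2, 0, 0, 1, 2]                (garbage column id == max_column_num)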
    # for table_id in range(len(examples)):
    #     row_num, column_num = examples[table_id]['table_size']
    #     for row_id in range(row_num):
    #         for column_id in range(column_num):
    #             assert column_id in column_token_position_to_column_ids[table_id, row_id]
    #     for masked_col_id in masked_column_token_column_ids[table_id]:
    #         assert masked_col_id < column_num
    #
    # assert column_token_position_to_column_ids.flatten().max() == table_mask.sum(axis=-1).max()
    tensor_dict = {
        'input_ids': torch.tensor(input_ids, dtype=torch.long),
        'segment_ids': torch.tensor(segment_ids, dtype=torch.long),
        'context_token_positions': torch.tensor(context_token_positions, dtype=torch.long),
        'column_token_position_to_column_ids': torch.tensor(column_token_position_to_column_ids, dtype=torch.long),
        'sequence_mask': torch.tensor(mask_array, dtype=torch.float32),
        'context_token_mask': torch.tensor(context_token_mask, dtype=torch.float32),
        'table_mask': torch.tensor(table_mask, dtype=torch.float32),
    }

    if train:
        sample_size = int((masked_context_token_label_ids != -1).sum() + (masked_column_token_label_ids != -1).sum())
        if predict_cell_tokens:
            sample_size += int((masked_cell_token_label_ids != -1).sum())

        tensor_dict.update({
            'masked_context_token_labels': torch.tensor(masked_context_token_label_ids, dtype=torch.long),
            'masked_column_token_column_ids': torch.tensor(masked_column_token_column_ids, dtype=torch.long),
            'masked_column_token_labels': torch.tensor(masked_column_token_label_ids, dtype=torch.long),
            'sample_size': sample_size
        })

        if predict_cell_tokens:
            tensor_dict.update({
                'masked_cell_token_positions': torch.tensor(masked_cell_token_positions, dtype=torch.long),
                'masked_cell_token_column_ids': torch.tensor(masked_cell_token_column_ids, dtype=torch.long),
                'masked_cell_token_labels': torch.tensor(masked_cell_token_label_ids, dtype=torch.long),
            })

    return tensor_dict
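

# --- Usage sketch (illustrative, not part of the original file) ---
# A minimal, hypothetical example of the per-example schema that `collate`
# expects, inferred from the field accesses above. The token ids and spans are
# made up, and the config is replaced by a SimpleNamespace stand-in; real
# callers would pass a VerticalAttentionTableBertConfig.
if __name__ == '__main__':
    from types import SimpleNamespace

    fake_example = {
        'rows': [{
            # hypothetical ids for "[CLS] hello [SEP] cell [SEP]"
            'token_ids': np.array([101, 7592, 102, 2054, 102], dtype=np.int64),
            'context_span': (0, 3),   # context occupies positions [0, 3)
            'segment_a_length': 3,    # context + first [SEP] form segment A
            # one column id per position; non-column tokens carry the sentinel
            'column_token_position_to_column_ids': np.array(
                [65535, 65535, 65535, 0, 65535], dtype=np.int64),
        }],
        'masked_context_token_positions': np.array([1]),
        'masked_context_token_label_ids': np.array([7592]),
        'masked_column_token_column_ids': np.array([0]),
        'masked_column_token_label_ids': np.array([2054]),
    }
    config_stub = SimpleNamespace(predict_cell_tokens=False)
    batch = collate([fake_example], config_stub, train=True)
    print({k: tuple(v.shape) for k, v in batch.items() if hasattr(v, 'shape')})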