in table_bert/input_formatter.py [0:0]
def get_row_input(self, context: List[str], header: List[Column], row_data: List[Any], trim_long_table=False):
    """Linearize the NL context and a single table row into one BERT input sequence,
    recording the token span of every column."""
    if self.config.context_first:
        table_tokens_start_idx = len(context) + 2  # account for [CLS] and [SEP]
        # account for [CLS] and [SEP], and the ending [SEP]
        max_table_token_length = MAX_BERT_INPUT_LENGTH - len(context) - 2 - 1
    else:
        table_tokens_start_idx = 1  # account for starting [CLS]
        # account for [CLS] and [SEP], and the ending [SEP]
        max_table_token_length = MAX_BERT_INPUT_LENGTH - len(context) - 2 - 1
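    # Worked example (assuming MAX_BERT_INPUT_LENGTH is the usual BERT limit of 512):
    # with a 10-token context, the table is left a budget of 512 - 10 - 3 = 499 tokens.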
    # generate table tokens
    row_input_tokens = []
    column_token_span_maps = []
    column_start_idx = table_tokens_start_idx

    for col_id, column in enumerate(header):
        value_tokens = row_data[col_id]
        truncated_value_tokens = value_tokens[:self.config.max_cell_len]

        column_input_tokens, token_span_map = self.get_cell_input(
            column,
            truncated_value_tokens,
            token_offset=column_start_idx
        )
        column_input_tokens.append(self.config.column_delimiter)
        early_stop = False
        if trim_long_table:
            if len(row_input_tokens) + len(column_input_tokens) > max_table_token_length:
                # the cell overflows the table budget: truncate it to fit and
                # clip its recorded sub-spans to the truncation point
                valid_column_input_token_len = max_table_token_length - len(row_input_tokens)
                column_input_tokens = column_input_tokens[:valid_column_input_token_len]
                end_index = column_start_idx + len(column_input_tokens)

                keys_to_delete = []
                for key in token_span_map:
                    if key in {'column_name', 'type', 'value', 'whole_span'}:
                        span_start_idx, span_end_idx = token_span_map[key]
                        if span_start_idx < end_index < span_end_idx:
                            # the span is cut in the middle: keep the surviving prefix
                            token_span_map[key] = (span_start_idx, end_index)
                        elif end_index < span_start_idx:
                            # the span lies entirely beyond the cut: drop it
                            keys_to_delete.append(key)
                    elif key == 'other_tokens':
                        old_positions = token_span_map[key]
                        new_positions = [idx for idx in old_positions if idx < end_index]
                        if not new_positions:
                            keys_to_delete.append(key)

                for key in keys_to_delete:
                    del token_span_map[key]

                # nothing left, we just skip this cell and break
                if len(token_span_map) == 0:
                    break

                early_stop = True
            elif len(row_input_tokens) + len(column_input_tokens) == max_table_token_length:
                early_stop = True
        elif len(row_input_tokens) + len(column_input_tokens) > max_table_token_length:
            # trimming disabled: an overflowing cell simply ends the table here
            break

        row_input_tokens.extend(column_input_tokens)
        column_start_idx = column_start_idx + len(column_input_tokens)
        column_token_span_maps.append(token_span_map)

        if early_stop:
            break
    # it is possible that the first cell is too long and cannot fit into `max_table_token_length`
    # we need to discard this sample
    if len(row_input_tokens) == 0:
        raise TableTooLongError()

    if row_input_tokens[-1] == self.config.column_delimiter:
        del row_input_tokens[-1]
    if self.config.context_first:
        sequence = ['[CLS]'] + context + ['[SEP]'] + row_input_tokens + ['[SEP]']
        # segment_ids = [0] * (len(context) + 2) + [1] * (len(row_input_tokens) + 1)
        segment_a_length = len(context) + 2
        context_span = (0, 1 + len(context))
        # context_token_indices = list(range(0, 1 + len(context)))
    else:
        sequence = ['[CLS]'] + row_input_tokens + ['[SEP]'] + context + ['[SEP]']
        # segment_ids = [0] * (len(row_input_tokens) + 2) + [1] * (len(context) + 1)
        segment_a_length = len(row_input_tokens) + 2
        context_span = (len(row_input_tokens) + 1, len(row_input_tokens) + 1 + 1 + len(context) + 1)
        # context_token_indices = list(range(len(row_input_tokens) + 1, len(row_input_tokens) + 1 + 1 + len(context) + 1))
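    # Worked example of the layout: with a 4-token context and 10 table tokens,
    # context-first gives
    #   sequence = [CLS] c1..c4 [SEP] t1..t10 [SEP]    (length 17)
    #   segment_a_length = 6, context_span = (0, 5)
    # while table-first gives
    #   sequence = [CLS] t1..t10 [SEP] c1..c4 [SEP]    (length 17)
    #   segment_a_length = 12, context_span = (11, 17)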
    instance = {
        'tokens': sequence,
        # 'token_ids': self.tokenizer.convert_tokens_to_ids(sequence),
        'segment_a_length': segment_a_length,
        # 'segment_ids': segment_ids,
        'column_spans': column_token_span_maps,
        'context_length': 1 + len(context),  # beginning [CLS]/[SEP] + input question
        'context_span': context_span,
        # 'context_token_indices': context_token_indices
    }

    return instance
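
For orientation, here is a minimal sketch of how this method might be driven end to end. The formatter class name, the Column constructor, the config defaults, the import paths, and the tokenizer used below are assumptions about the surrounding TaBERT code base rather than verbatim repository API; the snippet is meant only to show the expected shapes of the inputs and of the returned instance.

# Hypothetical driver for get_row_input; names and defaults below are assumptions.
from transformers import BertTokenizer            # assumed tokenizer; the repo may use its own wrapper
from table_bert.config import TableBertConfig     # assumed import path
from table_bert.table import Column               # assumed import path
from table_bert.input_formatter import VanillaTableBertInputFormatter  # assumed class name

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
config = TableBertConfig()                        # assume context_first=True by default
formatter = VanillaTableBertInputFormatter(config, tokenizer)  # assumed constructor signature

# the context is a list of word-piece tokens; each row cell is pre-tokenized as well
context = tokenizer.tokenize('highest points scored in 2006')
header = [Column('Year', 'real'), Column('Points', 'real')]    # assumed Column fields
row_data = [tokenizer.tokenize('2006'), tokenizer.tokenize('101')]

instance = formatter.get_row_input(context, header, row_data, trim_long_table=True)

# With context_first=True, instance['tokens'] has the shape
#   [CLS] <context tokens> [SEP] <Year cell> <column_delimiter> <Points cell> [SEP]
# and instance['column_spans'][i] maps keys such as 'column_name', 'type',
# 'value' and 'whole_span' to token offsets of column i within that sequence.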