in code/source/bert_preprocessing.py [0:0]
def convert_single_example(tokenizer, example, tag2int, max_seq_length=256):
"""
Converts a single `InputExample` into a single `InputFeatures`.
:param tokenizer: tokenizer created by create_tokenizer_from_hub_module
:param example: example created by convert_text_to_examples
:param tag2int: (dict) dictionary of tags to corresponding integer conversion
:param max_seq_length: (int) length of input example (input size of bert model)
:return: input_ids, input_masks, segment_ids (all three as input for BERT model,
and label_ids (true labels, useful for testing).
At inference, we create placeholder label_ids that we don't reuse (eg '-PAD-')
"""
if isinstance(example, PaddingInputExample):
input_ids = [0] * max_seq_length
input_mask = [0] * max_seq_length
segment_ids = [0] * max_seq_length
label_ids = [0] * max_seq_length
return input_ids, input_mask, segment_ids, label_ids
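    # example.text_a is expected to be a list of words (one label per word).
    # Drop any words that would not fit once [CLS] and [SEP] are added.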
tokens_a = example.text_a
if len(tokens_a) > max_seq_length-2:
tokens_a = tokens_a[0 : (max_seq_length-2)]
    # orig_to_tok_map is an int -> int mapping from each `orig_tokens` index to the
    # index of that word's first WordPiece in `tokens` ([CLS] and [SEP] get entries too).
    # Example for orig_tokens == ["John", "Johanson", "'s", "house"]:
    #   tokens          == ["[CLS]", "john", "johan", "##son", "'", "s", "house", "[SEP]"]
    #   orig_to_tok_map == [0, 1, 2, 4, 6, 7]
orig_to_tok_map = []
tokens = []
segment_ids = []
tokens.append("[CLS]")
segment_ids.append(0)
orig_to_tok_map.append(len(tokens)-1)
for token in tokens_a:
orig_to_tok_map.append(len(tokens))
tokens.extend(tokenizer.tokenize(token))
segment_ids.append(0)
tokens.append("[SEP]")
segment_ids.append(0)
orig_to_tok_map.append(len(tokens)-1)
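    # Feed BERT only the first WordPiece of each original word (plus [CLS] and [SEP]),
    # so that there is exactly one input position per word-level label.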
input_ids = tokenizer.convert_tokens_to_ids([tokens[i] for i in orig_to_tok_map])
# The mask has 1 for real tokens and 0 for padding tokens. Only real
# tokens are attended to.
input_mask = [1] * len(input_ids)
    label_ids = []
    # Keep the labels aligned with the (possibly truncated) word list; id 0 is the
    # placeholder label (the '-PAD-' tag) used for the [CLS] and [SEP] positions.
    labels = example.label[0 : (max_seq_length-2)]
    label_ids.append(0)
    label_ids.extend([tag2int[label] for label in labels])
    label_ids.append(0)
# Zero-pad up to the sequence length.
while len(input_ids) < max_seq_length:
input_ids.append(0)
input_mask.append(0)
segment_ids.append(0)
label_ids.append(0)
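    # Every returned sequence must be exactly max_seq_length long.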
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
assert len(label_ids) == max_seq_length
return input_ids, input_mask, segment_ids, label_ids
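

# ---------------------------------------------------------------------------
# Usage sketch (illustration only, not part of the preprocessing pipeline).
# Assumptions: in the real pipeline the tokenizer comes from
# create_tokenizer_from_hub_module and the example from convert_text_to_examples;
# here a toy whitespace tokenizer and a bare namedtuple stand in for them so the
# returned shapes can be inspected without downloading a BERT module.
if __name__ == "__main__":
    from collections import namedtuple

    class _ToyTokenizer:
        """Whitespace 'tokenizer' with a tiny hard-coded vocabulary (illustration only)."""
        _vocab = {"[PAD]": 0, "[CLS]": 101, "[SEP]": 102,
                  "john": 1, "lives": 2, "in": 3, "london": 4}

        def tokenize(self, text):
            return text.lower().split()

        def convert_tokens_to_ids(self, tokens):
            return [self._vocab.get(token, 0) for token in tokens]

    ToyExample = namedtuple("ToyExample", ["text_a", "label"])
    example = ToyExample(text_a=["John", "lives", "in", "London"],
                         label=["B-PER", "O", "O", "B-LOC"])
    tag2int = {"-PAD-": 0, "O": 1, "B-PER": 2, "B-LOC": 3}

    input_ids, input_mask, segment_ids, label_ids = convert_single_example(
        _ToyTokenizer(), example, tag2int, max_seq_length=8)
    print(input_ids)   # [101, 1, 2, 3, 4, 102, 0, 0]
    print(input_mask)  # [1, 1, 1, 1, 1, 1, 0, 0]
    print(label_ids)   # [0, 2, 1, 1, 3, 0, 0, 0]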