# archived/bert_attention_head_view/entry_point/data.py
def _transform(self, *record):
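    """Convert one SQuAD record into a list of `SQuADFeature` objects.

    Long documents are split into overlapping spans (see the sliding-window
    comment below), so a single example can yield several features. Returns
    `None` when the record cannot be turned into a SQuAD example.
    """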
    example = self._toSquadExample(record)
    if not example:
        return None
    padding = self.tokenizer.vocab.padding_token
    if self.do_lookup:
        padding = self.tokenizer.vocab[padding]
    features = []
    query_tokens = self.tokenizer(example.question_text)
    if len(query_tokens) > self.max_query_length:
        query_tokens = query_tokens[0:self.max_query_length]
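    # Re-tokenize every whitespace-delimited document token into word-pieces.
    # `orig_to_tok_index[i]` is the index of the first word-piece of original
    # token i, and `tok_to_orig_index[j]` maps word-piece j back to its
    # original token. Illustrative example: doc_tokens = ["unaffable", "day"]
    # might become all_doc_tokens = ["una", "##ffa", "##ble", "day"] with
    # orig_to_tok_index = [0, 3] and tok_to_orig_index = [0, 0, 0, 1].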
    tok_to_orig_index = []
    orig_to_tok_index = []
    all_doc_tokens = []
    for (i, token) in enumerate(example.doc_tokens):
        orig_to_tok_index.append(len(all_doc_tokens))
        sub_tokens = self.tokenizer(token)
        for sub_token in sub_tokens:
            tok_to_orig_index.append(i)
            all_doc_tokens.append(sub_token)
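    # For training, project the answer span from original-token indices onto
    # word-piece indices; impossible questions get the sentinel span (-1, -1).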
    tok_start_position = None
    tok_end_position = None
    if self.is_training and example.is_impossible:
        tok_start_position = -1
        tok_end_position = -1
    if self.is_training and not example.is_impossible:
        tok_start_position = orig_to_tok_index[example.start_position]
        if example.end_position < len(example.doc_tokens) - 1:
            tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
        else:
            tok_end_position = len(all_doc_tokens) - 1
        (tok_start_position, tok_end_position) = _improve_answer_span(
            all_doc_tokens,
            tok_start_position,
            tok_end_position,
            self.tokenizer,
            example.orig_answer_text,
        )
    # The -3 accounts for [CLS], [SEP] and [SEP]
    max_tokens_for_doc = self.max_seq_length - len(query_tokens) - 3

    # We can have documents that are longer than the maximum sequence length.
    # To deal with this we use a sliding-window approach, taking chunks of up
    # to our max length with a stride of `doc_stride`.
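    # Illustrative example (hypothetical sizes): with 10 document word-pieces,
    # max_tokens_for_doc = 6 and doc_stride = 3, the loop below produces the
    # spans (start=0, length=6), (start=3, length=6), (start=6, length=4),
    # so every word-piece appears in at least one span.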
    _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
        "DocSpan", ["start", "length"]
    )
    doc_spans = []
    start_offset = 0
    while start_offset < len(all_doc_tokens):
        length = len(all_doc_tokens) - start_offset
        if length > max_tokens_for_doc:
            length = max_tokens_for_doc
        doc_spans.append(_DocSpan(start=start_offset, length=length))
        if start_offset + length == len(all_doc_tokens):
            break
        start_offset += min(length, self.doc_stride)
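    # Each document span becomes one feature laid out as
    # [CLS] query tokens [SEP] document chunk [SEP], with segment id 0 for the
    # query part and 1 for the document part.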
    for (doc_span_index, doc_span) in enumerate(doc_spans):
        tokens = []
        token_to_orig_map = {}
        token_is_max_context = {}
        segment_ids = []
        tokens.append(self.tokenizer.vocab.cls_token)
        segment_ids.append(0)
        for token in query_tokens:
            tokens.append(token)
            segment_ids.append(0)
        tokens.append(self.tokenizer.vocab.sep_token)
        segment_ids.append(0)
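        # A word-piece can occur in several overlapping spans;
        # `token_is_max_context` records, per position in this feature, whether
        # this span gives that word-piece its maximal surrounding context, and
        # `token_to_orig_map` links the position back to the original token.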
        for i in range(doc_span.length):
            split_token_index = doc_span.start + i
            token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]
            is_max_context = _check_is_max_context(
                doc_spans, doc_span_index, split_token_index
            )
            token_is_max_context[len(tokens)] = is_max_context
            tokens.append(all_doc_tokens[split_token_index])
            segment_ids.append(1)
        tokens.append(self.tokenizer.vocab.sep_token)
        segment_ids.append(1)
        if self.do_lookup:
            input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
        else:
            input_ids = tokens
        # `valid_length` is the number of real (non-padding) tokens; downstream,
        # only these tokens are attended to.
        valid_length = len(input_ids)
        # Zero-pad up to the sequence length.
        if self.is_pad:
            while len(input_ids) < self.max_seq_length:
                input_ids.append(padding)
                segment_ids.append(padding)
            assert len(input_ids) == self.max_seq_length
            assert len(segment_ids) == self.max_seq_length
        start_position = 0
        end_position = 0
        if self.is_training and not example.is_impossible:
            # For training, if this document chunk does not contain the
            # annotated answer span, keep the feature but point both positions
            # at index 0 (the [CLS] token), since there is nothing to predict
            # inside this chunk.
            doc_start = doc_span.start
            doc_end = doc_span.start + doc_span.length - 1
            out_of_span = False
            if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
                out_of_span = True
            if out_of_span:
                start_position = 0
                end_position = 0
            else:
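                # The +2 offsets past the leading [CLS] and the [SEP] that
                # follows the query tokens.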
                doc_offset = len(query_tokens) + 2
                start_position = tok_start_position - doc_start + doc_offset
                end_position = tok_end_position - doc_start + doc_offset
        if self.is_training and example.is_impossible:
            start_position = 0
            end_position = 0
        features.append(
            SQuADFeature(
                example_id=example.example_id,
                qas_id=example.qas_id,
                doc_tokens=example.doc_tokens,
                doc_span_index=doc_span_index,
                tokens=tokens,
                token_to_orig_map=token_to_orig_map,
                token_is_max_context=token_is_max_context,
                input_ids=input_ids,
                valid_length=valid_length,
                segment_ids=segment_ids,
                start_position=start_position,
                end_position=end_position,
                is_impossible=example.is_impossible,
            )
        )
    return features