def _transform()

in archived/bert_attention_head_view/entry_point/data.py [0:0]
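
Converts a single SQuAD record into a list of SQuADFeature objects, splitting documents longer than the maximum sequence length into overlapping spans with a sliding window; returns None if the record cannot be turned into a SQuAD example.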


    def _transform(self, *record):
        example = self._toSquadExample(record)

        if not example:
            return None

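        # The padding value is the padding token itself, or its vocabulary id
        # when do_lookup converts tokens to ids.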
        padding = self.tokenizer.vocab.padding_token
        if self.do_lookup:
            padding = self.tokenizer.vocab[padding]
        features = []
        query_tokens = self.tokenizer(example.question_text)

        if len(query_tokens) > self.max_query_length:
            query_tokens = query_tokens[0 : self.max_query_length]

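        # Break each whitespace token into sub-tokens while keeping index maps
        # in both directions, so answer spans can be projected between the
        # original tokens and the sub-token sequence.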
        tok_to_orig_index = []
        orig_to_tok_index = []
        all_doc_tokens = []
        for (i, token) in enumerate(example.doc_tokens):
            orig_to_tok_index.append(len(all_doc_tokens))
            sub_tokens = self.tokenizer(token)
            for sub_token in sub_tokens:
                tok_to_orig_index.append(i)
                all_doc_tokens.append(sub_token)

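        # For training, project the answer's start/end word indices onto the
        # sub-token sequence; _improve_answer_span tightens the span so it
        # better matches the annotated answer text.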
        tok_start_position = None
        tok_end_position = None
        if self.is_training and example.is_impossible:
            tok_start_position = -1
            tok_end_position = -1
        if self.is_training and not example.is_impossible:
            tok_start_position = orig_to_tok_index[example.start_position]
            if example.end_position < len(example.doc_tokens) - 1:
                tok_end_position = orig_to_tok_index[example.end_position + 1] - 1
            else:
                tok_end_position = len(all_doc_tokens) - 1
            (tok_start_position, tok_end_position) = _improve_answer_span(
                all_doc_tokens,
                tok_start_position,
                tok_end_position,
                self.tokenizer,
                example.orig_answer_text,
            )

        # The -3 accounts for [CLS], [SEP] and [SEP]
        max_tokens_for_doc = self.max_seq_length - len(query_tokens) - 3

        # We can have documents that are longer than the maximum sequence length.
        # To deal with this we do a sliding window approach, where we take chunks
        # of up to our max length with a stride of `doc_stride`.
        _DocSpan = collections.namedtuple(  # pylint: disable=invalid-name
            "DocSpan", ["start", "length"]
        )
        doc_spans = []
        start_offset = 0
        while start_offset < len(all_doc_tokens):
            length = len(all_doc_tokens) - start_offset
            if length > max_tokens_for_doc:
                length = max_tokens_for_doc
            doc_spans.append(_DocSpan(start=start_offset, length=length))
            if start_offset + length == len(all_doc_tokens):
                break
            start_offset += min(length, self.doc_stride)

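        # Build one feature per document span: [CLS] query [SEP] chunk [SEP],
        # with segment id 0 for the query and 1 for the document chunk.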
        for (doc_span_index, doc_span) in enumerate(doc_spans):
            tokens = []
            token_to_orig_map = {}
            token_is_max_context = {}
            segment_ids = []
            tokens.append(self.tokenizer.vocab.cls_token)
            segment_ids.append(0)
            for token in query_tokens:
                tokens.append(token)
                segment_ids.append(0)
            tokens.append(self.tokenizer.vocab.sep_token)
            segment_ids.append(0)

            for i in range(doc_span.length):
                split_token_index = doc_span.start + i
                token_to_orig_map[len(tokens)] = tok_to_orig_index[split_token_index]

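                # A token may appear in several overlapping spans; record whether
                # this span gives it the most surrounding context.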
                is_max_context = _check_is_max_context(
                    doc_spans, doc_span_index, split_token_index
                )
                token_is_max_context[len(tokens)] = is_max_context
                tokens.append(all_doc_tokens[split_token_index])
                segment_ids.append(1)
            tokens.append(self.tokenizer.vocab.sep_token)
            segment_ids.append(1)

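            # Emit vocabulary ids or the raw tokens, depending on do_lookup.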
            if self.do_lookup:
                input_ids = self.tokenizer.convert_tokens_to_ids(tokens)
            else:
                input_ids = tokens

            # Only real tokens are attended to; valid_length records the number
            # of real tokens so the attention mask (1 for real tokens, 0 for
            # padding) can be built from it downstream.
            valid_length = len(input_ids)

            # Pad up to the sequence length with the padding value.
            if self.is_pad:
                while len(input_ids) < self.max_seq_length:
                    input_ids.append(padding)
                    segment_ids.append(padding)

                assert len(input_ids) == self.max_seq_length
                assert len(segment_ids) == self.max_seq_length

            start_position = 0
            end_position = 0
            if self.is_training and not example.is_impossible:
                # For training, if our document chunk does not contain an annotation
                # we throw it out, since there is nothing to predict.
                doc_start = doc_span.start
                doc_end = doc_span.start + doc_span.length - 1
                out_of_span = False
                if not (tok_start_position >= doc_start and tok_end_position <= doc_end):
                    out_of_span = True
                if out_of_span:
                    start_position = 0
                    end_position = 0
                else:
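                    # Skip past [CLS], the query tokens, and the first [SEP].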
                    doc_offset = len(query_tokens) + 2
                    start_position = tok_start_position - doc_start + doc_offset
                    end_position = tok_end_position - doc_start + doc_offset

            if self.is_training and example.is_impossible:
                start_position = 0
                end_position = 0

            features.append(
                SQuADFeature(
                    example_id=example.example_id,
                    qas_id=example.qas_id,
                    doc_tokens=example.doc_tokens,
                    doc_span_index=doc_span_index,
                    tokens=tokens,
                    token_to_orig_map=token_to_orig_map,
                    token_is_max_context=token_is_max_context,
                    input_ids=input_ids,
                    valid_length=valid_length,
                    segment_ids=segment_ids,
                    start_position=start_position,
                    end_position=end_position,
                    is_impossible=example.is_impossible,
                )
            )

        return features
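
Usage sketch (illustrative only): the class name, constructor arguments, and the
raw-record shape below are assumptions inferred from the attributes this method
uses; they are not confirmed by the archived module.

    # Hypothetical setup: a SQuADTransform-style class wrapping _transform.
    transform = SQuADTransform(
        tokenizer,            # BERT tokenizer exposing .vocab (assumed)
        max_seq_length=384,
        doc_stride=128,
        max_query_length=64,
        is_pad=True,
        is_training=True,
    )
    # 'record' is a raw SQuAD entry with the fields _toSquadExample expects.
    # Each record expands into one SQuADFeature per sliding-window span of
    # the document (or None if it cannot be converted into an example).
    features = transform._transform(*record)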