def _find_max_sequence_length()

in eland/ml/pytorch/transformers.py


    def _find_max_sequence_length(self) -> int:
        # Sometimes the max_... values are present but contain
        # a random or very large value.
        REASONABLE_MAX_LENGTH = 8192
        max_len = getattr(self._tokenizer, "model_max_length", None)
        if max_len is not None and max_len <= REASONABLE_MAX_LENGTH:
            return int(max_len)

        max_sizes = getattr(self._tokenizer, "max_model_input_sizes", dict())
        max_len = max_sizes.get(self._model_id)
        if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
            return int(max_len)

        if max_sizes:
            # The model id wasn't found in the max sizes dict, but
            # if all the values agree then use that shared value
            sizes = {size for size in max_sizes.values()}
            if len(sizes) == 1:
                max_len = sizes.pop()
                if max_len is not None and max_len < REASONABLE_MAX_LENGTH:
                    return int(max_len)

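        # BERT and its derivatives have 512 positional embeddings, so 512 is a
        # safe default when the tokenizer exposes no explicit limit.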
        if isinstance(self._tokenizer, BertTokenizer):
            return 512

        raise UnknownModelInputSizeError("Cannot determine model max input length")
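For context, the short probe below (not part of eland) shows the tokenizer attributes the method falls back through. It is a sketch: bert-base-uncased is only an illustrative model id, and max_model_input_sizes may be empty or missing on newer transformers releases.

from transformers import AutoTokenizer

model_id = "bert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_id)

# First check: usually a sane value such as 512. When the model config sets no
# limit, transformers substitutes a huge sentinel (int(1e30)), which the
# REASONABLE_MAX_LENGTH guard filters out.
print(tokenizer.model_max_length)

# Second and third checks: a per-model lookup table that older tokenizer
# classes expose; getattr() hedges against it being absent.
print(getattr(tokenizer, "max_model_input_sizes", {}))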