sockeye/data_io.py [1011:1328]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    return train_iter, validation_iter, config_data, data_info


def get_scoring_data_iters(sources: List[str],
                           targets: List[str],
                           source_vocabs: List[vocab.Vocab],
                           target_vocabs: List[vocab.Vocab],
                           batch_size: int,
                           max_seq_len_source: int,
                           max_seq_len_target: int) -> 'BaseParallelSampleIter':
    """
    Builds the data iterator used for scoring. Data is loaded on demand, batch by batch,
    no lines are skipped, and lines that are too long are truncated.

    :param sources: Path to source data (with optional factor data paths).
    :param targets: Path to target data (with optional factor data paths).
    :param source_vocabs: Source vocabulary and optional factor vocabularies.
    :param target_vocabs: Target vocabulary and optional factor vocabularies.
    :param batch_size: Batch size.
    :param max_seq_len_source: Maximum source sequence length.
    :param max_seq_len_target: Maximum target sequence length.
    :return: The scoring data iterator.
    """
    for banner_line in ("==============================",
                        "Creating scoring data iterator",
                        "=============================="):
        logger.info(banner_line)

    # A single bucket sized by the maximum lengths holds every sequence pair.
    max_lens = (max_seq_len_source, max_seq_len_target)

    # Blank lines are kept (skip_blanks=False) so that no input line is dropped.
    loader = RawParallelDatasetLoader(buckets=[max_lens],
                                      eos_id=C.EOS_ID,
                                      pad_id=C.PAD_ID,
                                      skip_blanks=False)

    # The number of factors on each side equals the number of files given for that side.
    return BatchedRawParallelSampleIter(data_loader=loader,
                                        sources=sources,
                                        targets=targets,
                                        source_vocabs=source_vocabs,
                                        target_vocabs=target_vocabs,
                                        bucket=max_lens,
                                        batch_size=batch_size,
                                        max_lens=max_lens,
                                        num_source_factors=len(sources),
                                        num_target_factors=len(targets))


@dataclass
class LengthStatistics(config.Config):
    """
    Holds a sentence count together with the mean and standard deviation of a length ratio.
    """
    # NOTE(review): presumably the number of sentences the ratio statistics were computed over — confirm.
    num_sents: int
    # Mean of the length ratio.
    # NOTE(review): presumably target/source length, as in the per-bucket log of this file — confirm.
    length_ratio_mean: float
    # Standard deviation of the length ratio.
    length_ratio_std: float


@dataclass
class DataStatistics(config.Config):
    """
    Aggregated counts and length statistics of a parallel data set, overall and per bucket.
    """
    num_sents: int
    num_discarded: int
    num_tokens_source: int
    num_tokens_target: int
    num_unks_source: int
    num_unks_target: int
    max_observed_len_source: int
    max_observed_len_target: int
    size_vocab_source: int
    size_vocab_target: int
    length_ratio_mean: float
    length_ratio_std: float
    buckets: List[Tuple[int, int]]
    num_sents_per_bucket: List[int]
    average_len_target_per_bucket: List[Optional[float]]
    length_ratio_stats_per_bucket: Optional[List[Tuple[Optional[float], Optional[float]]]] = None

    def log(self, bucket_batch_sizes: Optional[List[BucketBatchSize]] = None):
        """
        Writes a summary of the statistics to the log.

        :param bucket_batch_sizes: Optional batch sizes per bucket. If given, a per-bucket
                                   description is logged as well.
        """
        src_tokens, trg_tokens = self.num_tokens_source, self.num_tokens_target
        src_unks, trg_unks = self.num_unks_source, self.num_unks_target

        logger.info("Tokens: source %d target %d", src_tokens, trg_tokens)
        logger.info("Number of <unk> tokens: source %d target %d", src_unks, trg_unks)

        # Coverage is only defined when both sides contain tokens (avoids a division by zero).
        if src_tokens > 0 and trg_tokens > 0:
            src_coverage = (1 - src_unks / src_tokens) * 100
            trg_coverage = (1 - trg_unks / trg_tokens) * 100
            logger.info("Vocabulary coverage: source %.0f%% target %.0f%%", src_coverage, trg_coverage)

        logger.info("%d sequences across %d buckets", self.num_sents, len(self.num_sents_per_bucket))
        logger.info("%d sequences did not fit into buckets and were discarded", self.num_discarded)

        if bucket_batch_sizes is None:
            return
        describe_data_and_buckets(self, bucket_batch_sizes)


def describe_data_and_buckets(data_statistics: DataStatistics, bucket_batch_sizes: List[BucketBatchSize]):
    """
    Logs one summary line for every bucket that contains at least one sample.

    :param data_statistics: Statistics of the data set. Its per-bucket length ratio statistics
                            may be None (the field's default) or contain None entries; missing
                            values are logged as 'nan'.
    :param bucket_batch_sizes: Batch size information, one entry per bucket in the statistics.
    """
    num_buckets = len(data_statistics.buckets)
    check_condition(len(bucket_batch_sizes) == num_buckets,
                    "Number of bucket batch sizes (%d) does not match number of buckets in statistics (%d)."
                    % (len(bucket_batch_sizes), num_buckets))

    ratio_stats = data_statistics.length_ratio_stats_per_bucket
    if ratio_stats is None:
        # The field is optional and defaults to None; zipping over None would raise a TypeError.
        ratio_stats = [(None, None)] * num_buckets

    for bucket_batch_size, num_seq, (lr_mean, lr_std) in zip(bucket_batch_sizes,
                                                             data_statistics.num_sents_per_bucket,
                                                             ratio_stats):
        if num_seq > 0:
            # Individual statistics are Optional as well; '%.2f' cannot format None, NaN prints as 'nan'.
            logger.info("Bucket %s: %d samples in %d batches of %d, ~%.1f target tokens/batch, "
                        "trg/src length ratio: %.2f (+-%.2f)",
                        bucket_batch_size.bucket,
                        num_seq,
                        math.ceil(num_seq / bucket_batch_size.batch_size),
                        bucket_batch_size.batch_size,
                        bucket_batch_size.average_target_words_per_batch,
                        math.nan if lr_mean is None else lr_mean,
                        math.nan if lr_std is None else lr_std)


@dataclass
class DataInfo(config.Config):
    """
    Stores training data information that is not relevant for inference.
    """
    # NOTE(review): presumably the paths of the source data and source factor files — confirm at the call site.
    sources: List[str]
    # NOTE(review): presumably the paths of the target data and target factor files — confirm at the call site.
    targets: List[str]
    # One optional string per source vocabulary; entries may be None.
    # NOTE(review): presumably vocabulary file paths — confirm.
    source_vocabs: List[Optional[str]]
    # One optional string per target vocabulary; entries may be None.
    # NOTE(review): presumably vocabulary file paths — confirm.
    target_vocabs: List[Optional[str]]
    # NOTE(review): presumably whether source and target use one shared vocabulary — confirm.
    shared_vocab: bool
    # NOTE(review): presumably the number of shards the training data is split into — confirm.
    num_shards: int


@dataclass
class DataConfig(config.Config):
    """
    Stores data statistics relevant for inference.
    """
    # Aggregated statistics of the data set (see DataStatistics).
    data_statistics: DataStatistics
    # Maximum source sequence length.
    max_seq_len_source: int
    # Maximum target sequence length.
    max_seq_len_target: int
    # Number of source factors.
    # NOTE(review): presumably includes the primary source stream — confirm.
    num_source_factors: int
    # Number of target factors.
    # NOTE(review): presumably includes the primary target stream — confirm.
    num_target_factors: int


def read_content(path: str, limit: Optional[int] = None) -> Iterator[List[str]]:
    """
    Lazily yields the tokens of each line in path, stopping once `limit` lines were read.

    :param path: Path to files containing sentences.
    :param limit: How many lines to read from path. None reads every line.
    :return: Iterator over lists of words.
    """
    unlimited = limit is None
    with smart_open(path) as lines:
        for line_index, line in enumerate(lines):
            # line_index counts the lines already yielded, so stop once it reaches the limit.
            if not unlimited and line_index == limit:
                break
            yield list(get_tokens(line))


def tokens2ids(tokens: Iterable[str], vocab: Dict[str, int]) -> List[int]:
    """
    Maps every token to its id in vocab; tokens missing from vocab map to the id of the UNK symbol.

    :param tokens: List of string tokens.
    :param vocab: Vocabulary (containing UNK symbol).
    :return: List of word ids.
    """
    def to_id(token: str) -> int:
        # The UNK id is looked up for every token, also for tokens that are in the vocabulary.
        return vocab.get(token, vocab[C.UNK_SYMBOL])

    return list(map(to_id, tokens))


def strids2ids(tokens: Iterable[str]) -> List[int]:
    """
    Converts a sequence of ids encoded as strings into a list of integers.

    :param tokens: List of integer tokens (as strings).
    :return: List of word ids.
    """
    return [int(string_id) for string_id in tokens]


def ids2tokens(token_ids: Iterable[int],
               vocab_inv: Dict[int, str],
               exclude_set: Set[int]) -> Iterator[str]:
    """
    Lazily transforms token IDs into words, dropping any IDs in `exclude_set`.

    The input is traversed exactly once, so one-shot iterators are handled correctly.
    Excluded IDs are never looked up and therefore do not need to be present in `vocab_inv`.

    :param token_ids: The token IDs.
    :param vocab_inv: The inverse vocabulary.
    :param exclude_set: The token IDs to exclude.
    :return: Iterator over the words of all IDs that are not excluded.
    :raises KeyError: While iterating, if a non-excluded ID is missing from `vocab_inv`.
    """
    # Filter before the lookup: a single pass keeps IDs and words aligned for any iterable.
    return (vocab_inv[token_id] for token_id in token_ids if token_id not in exclude_set)


class SequenceReader:
    """
    Streams sequence samples from a file and turns each line into a list of integer ids.
    Lines are read lazily from disk instead of being loaded into memory at once.
    With a vocabulary, tokens are mapped to ids through it; without one, the tokens in path
    are assumed to be integers coded as strings.
    Empty sequences are yielded as None.

    :param path: Path to read data from.
    :param vocabulary: Optional mapping from strings to integer ids.
    :param add_bos: Whether to add Beginning-Of-Sentence (BOS) symbol.
    :param add_eos: Whether to add End-Of-Sentence (EOS) symbol.
    :param limit: Read limit.
    """

    def __init__(self,
                 path: str,
                 vocabulary: Optional[vocab.Vocab] = None,
                 add_bos: bool = False,
                 add_eos: bool = False,
                 limit: Optional[int] = None) -> None:
        self.path = path
        self.vocab = vocabulary
        if vocabulary is None:
            # Without a vocabulary there are no ids for the BOS/EOS symbols.
            self.bos_id = None
            self.eos_id = None
            check_condition(not (add_bos or add_eos), "Adding a BOS or EOS symbol requires a vocabulary")
        else:
            self.bos_id = None
            self.eos_id = None
            assert vocab.is_valid_vocab(vocabulary)
            self.bos_id = C.BOS_ID
            self.eos_id = C.EOS_ID
        self.add_bos = add_bos
        self.add_eos = add_eos
        self.limit = limit

    def __iter__(self):
        for line_tokens in read_content(self.path, self.limit):
            if self.vocab is None:
                ids = strids2ids(line_tokens)
            else:
                ids = tokens2ids(line_tokens, self.vocab)

            if ids:
                if self.add_bos:
                    ids.insert(0, self.bos_id)
                if self.add_eos:
                    ids.append(self.eos_id)
                yield ids
            else:
                # Empty lines are signalled as None and never receive BOS/EOS symbols.
                yield None


def create_sequence_readers(sources: List[str], targets: List[str],
                            vocab_sources: List[vocab.Vocab],
                            vocab_targets: List[vocab.Vocab]) -> Tuple[List[SequenceReader], List[SequenceReader]]:
    """
    Creates one reader per file: source readers append EOS, target readers prepend BOS.
    Files and vocabularies are paired by position.

    :param sources: The file names of source data and factors.
    :param targets: The file names of the target data and factors.
    :param vocab_sources: The source vocabularies.
    :param vocab_targets: The target vocabularies.
    :return: The source sequence readers and the target sequence readers.
    """
    readers_for_sources = [SequenceReader(path, vocabulary, add_eos=True)
                           for path, vocabulary in zip(sources, vocab_sources)]
    readers_for_targets = [SequenceReader(path, vocabulary, add_bos=True)
                           for path, vocabulary in zip(targets, vocab_targets)]
    return readers_for_sources, readers_for_targets


def parallel_iter(source_iterables: Sequence[Iterable[Optional[Any]]],
                  target_iterables: Sequence[Iterable[Optional[Any]]],
                  skip_blanks: bool = True,
                  check_token_parallel: bool = True):
    """
    Turns every iterable into an iterator and delegates to parallel_iterate().
    The two steps are kept separate so that callers who want to keep iterator state
    between calls can call parallel_iterate() directly.

    :param source_iterables: A list of source iterables.
    :param target_iterables: A list of target iterables.
    :param skip_blanks: Whether to skip empty target lines.
    :param check_token_parallel: Whether to check if the tokens are parallel or not.
    :return: Iterators over sources and target.
    """
    started_sources = list(map(iter, source_iterables))
    started_targets = list(map(iter, target_iterables))
    return parallel_iterate(started_sources,
                            started_targets,
                            skip_blanks,
                            check_token_parallel)


def parallel_iterate(source_iterators: Sequence[Iterator[Optional[Any]]],
                     target_iterators: Sequence[Iterator[Optional[Any]]],
                     skip_blanks: bool = True,
                     check_token_parallel: bool = True):
    """
    Yields parallel source(s), target(s) sequences from iterators.
    Optionally checks for token parallelism within the source and within the target sequences.
    Optionally skips pairs where the element of at least one iterator is None.
    Checks that all iterators have the same number of elements.
    Can optionally continue from an already-begun iterator.

    :param source_iterators: A list of source iterators.
    :param target_iterators: A list of target iterators.
    :param skip_blanks: Whether to skip pairs in which any source or target element is None.
    :param check_token_parallel: Whether to check if the tokens are parallel or not.
    :return: Iterators over sources and target.
    """
    # Unique end marker. None cannot signal exhaustion because None is a regular element
    # (an empty line), so a trailing blank line in a longer input would go unnoticed.
    exhausted = object()
    num_skipped = 0
    same_length = True
    while True:
        # Advance every iterator in the same round so that no element is consumed and then
        # silently dropped when another iterator turns out to be exhausted.
        sources = [next(source_iter, exhausted) for source_iter in source_iterators]
        targets = [next(target_iter, exhausted) for target_iter in target_iterators]
        ended = [element is exhausted for element in sources + targets]
        if any(ended):
            # The inputs have equal length only if all of them ran out in this very round.
            same_length = all(ended)
            break
        if skip_blanks and (any((s is None for s in sources)) or any((t is None for t in targets))):
            num_skipped += 1
            continue
        if check_token_parallel:
            check_condition(are_none(sources) or are_token_parallel(sources),
                            "Source sequences are not token-parallel: %s" % (str(sources)))
            check_condition(are_none(targets) or are_token_parallel(targets),
                            "Target sequences are not token-parallel: %s" % (str(targets)))
        yield sources, targets

    if num_skipped > 0:
        logger.warning("Parallel reading of sequences skipped %d elements", num_skipped)

    check_condition(same_length, "Different number of lines in source(s) and target(s) iterables.")
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



sockeye/data_io_pt.py [990:1307]:
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -
    return train_iter, validation_iter, config_data, data_info


def get_scoring_data_iters(sources: List[str],
                           targets: List[str],
                           source_vocabs: List[vocab.Vocab],
                           target_vocabs: List[vocab.Vocab],
                           batch_size: int,
                           max_seq_len_source: int,
                           max_seq_len_target: int) -> 'BaseParallelSampleIter':
    """
    Builds the data iterator used for scoring. Data is loaded on demand, batch by batch,
    no lines are skipped, and lines that are too long are truncated.

    :param sources: Path to source data (with optional factor data paths).
    :param targets: Path to target data (with optional factor data paths).
    :param source_vocabs: Source vocabulary and optional factor vocabularies.
    :param target_vocabs: Target vocabulary and optional factor vocabularies.
    :param batch_size: Batch size.
    :param max_seq_len_source: Maximum source sequence length.
    :param max_seq_len_target: Maximum target sequence length.
    :return: The scoring data iterator.
    """
    for banner_line in ("==============================",
                        "Creating scoring data iterator",
                        "=============================="):
        logger.info(banner_line)

    # A single bucket sized by the maximum lengths holds every sequence pair.
    max_lens = (max_seq_len_source, max_seq_len_target)

    # Blank lines are kept (skip_blanks=False) so that no input line is dropped.
    loader = RawParallelDatasetLoader(buckets=[max_lens],
                                      eos_id=C.EOS_ID,
                                      pad_id=C.PAD_ID,
                                      skip_blanks=False)

    # The number of factors on each side equals the number of files given for that side.
    return BatchedRawParallelSampleIter(data_loader=loader,
                                        sources=sources,
                                        targets=targets,
                                        source_vocabs=source_vocabs,
                                        target_vocabs=target_vocabs,
                                        bucket=max_lens,
                                        batch_size=batch_size,
                                        max_lens=max_lens,
                                        num_source_factors=len(sources),
                                        num_target_factors=len(targets))


@dataclass
class LengthStatistics(config.Config):
    """
    Holds a sentence count together with the mean and standard deviation of a length ratio.
    """
    # NOTE(review): presumably the number of sentences the ratio statistics were computed over — confirm.
    num_sents: int
    # Mean of the length ratio.
    # NOTE(review): presumably target/source length, as in the per-bucket log of this file — confirm.
    length_ratio_mean: float
    # Standard deviation of the length ratio.
    length_ratio_std: float


@dataclass
class DataStatistics(config.Config):
    """
    Aggregated counts and length statistics of a parallel data set, overall and per bucket.
    """
    num_sents: int
    num_discarded: int
    num_tokens_source: int
    num_tokens_target: int
    num_unks_source: int
    num_unks_target: int
    max_observed_len_source: int
    max_observed_len_target: int
    size_vocab_source: int
    size_vocab_target: int
    length_ratio_mean: float
    length_ratio_std: float
    buckets: List[Tuple[int, int]]
    num_sents_per_bucket: List[int]
    average_len_target_per_bucket: List[Optional[float]]
    length_ratio_stats_per_bucket: Optional[List[Tuple[Optional[float], Optional[float]]]] = None

    def log(self, bucket_batch_sizes: Optional[List[BucketBatchSize]] = None):
        """
        Writes a summary of the statistics to the log.

        :param bucket_batch_sizes: Optional batch sizes per bucket. If given, a per-bucket
                                   description is logged as well.
        """
        src_tokens, trg_tokens = self.num_tokens_source, self.num_tokens_target
        src_unks, trg_unks = self.num_unks_source, self.num_unks_target

        logger.info("Tokens: source %d target %d", src_tokens, trg_tokens)
        logger.info("Number of <unk> tokens: source %d target %d", src_unks, trg_unks)

        # Coverage is only defined when both sides contain tokens (avoids a division by zero).
        if src_tokens > 0 and trg_tokens > 0:
            src_coverage = (1 - src_unks / src_tokens) * 100
            trg_coverage = (1 - trg_unks / trg_tokens) * 100
            logger.info("Vocabulary coverage: source %.0f%% target %.0f%%", src_coverage, trg_coverage)

        logger.info("%d sequences across %d buckets", self.num_sents, len(self.num_sents_per_bucket))
        logger.info("%d sequences did not fit into buckets and were discarded", self.num_discarded)

        if bucket_batch_sizes is None:
            return
        describe_data_and_buckets(self, bucket_batch_sizes)


def describe_data_and_buckets(data_statistics: DataStatistics, bucket_batch_sizes: List[BucketBatchSize]):
    """
    Logs one summary line for every bucket that contains at least one sample.

    :param data_statistics: Statistics of the data set. Its per-bucket length ratio statistics
                            may be None (the field's default) or contain None entries; missing
                            values are logged as 'nan'.
    :param bucket_batch_sizes: Batch size information, one entry per bucket in the statistics.
    """
    num_buckets = len(data_statistics.buckets)
    check_condition(len(bucket_batch_sizes) == num_buckets,
                    "Number of bucket batch sizes (%d) does not match number of buckets in statistics (%d)."
                    % (len(bucket_batch_sizes), num_buckets))

    ratio_stats = data_statistics.length_ratio_stats_per_bucket
    if ratio_stats is None:
        # The field is optional and defaults to None; zipping over None would raise a TypeError.
        ratio_stats = [(None, None)] * num_buckets

    for bucket_batch_size, num_seq, (lr_mean, lr_std) in zip(bucket_batch_sizes,
                                                             data_statistics.num_sents_per_bucket,
                                                             ratio_stats):
        if num_seq > 0:
            # Individual statistics are Optional as well; '%.2f' cannot format None, NaN prints as 'nan'.
            logger.info("Bucket %s: %d samples in %d batches of %d, ~%.1f target tokens/batch, "
                        "trg/src length ratio: %.2f (+-%.2f)",
                        bucket_batch_size.bucket,
                        num_seq,
                        math.ceil(num_seq / bucket_batch_size.batch_size),
                        bucket_batch_size.batch_size,
                        bucket_batch_size.average_target_words_per_batch,
                        math.nan if lr_mean is None else lr_mean,
                        math.nan if lr_std is None else lr_std)


@dataclass
class DataInfo(config.Config):
    """
    Stores training data information that is not relevant for inference.
    """
    # NOTE(review): presumably the paths of the source data and source factor files — confirm at the call site.
    sources: List[str]
    # NOTE(review): presumably the paths of the target data and target factor files — confirm at the call site.
    targets: List[str]
    # One optional string per source vocabulary; entries may be None.
    # NOTE(review): presumably vocabulary file paths — confirm.
    source_vocabs: List[Optional[str]]
    # One optional string per target vocabulary; entries may be None.
    # NOTE(review): presumably vocabulary file paths — confirm.
    target_vocabs: List[Optional[str]]
    # NOTE(review): presumably whether source and target use one shared vocabulary — confirm.
    shared_vocab: bool
    # NOTE(review): presumably the number of shards the training data is split into — confirm.
    num_shards: int


@dataclass
class DataConfig(config.Config):
    """
    Stores data statistics relevant for inference.
    """
    # Aggregated statistics of the data set (see DataStatistics).
    data_statistics: DataStatistics
    # Maximum source sequence length.
    max_seq_len_source: int
    # Maximum target sequence length.
    max_seq_len_target: int
    # Number of source factors.
    # NOTE(review): presumably includes the primary source stream — confirm.
    num_source_factors: int
    # Number of target factors.
    # NOTE(review): presumably includes the primary target stream — confirm.
    num_target_factors: int


def read_content(path: str, limit: Optional[int] = None) -> Iterator[List[str]]:
    """
    Lazily yields the tokens of each line in path, stopping once `limit` lines were read.

    :param path: Path to files containing sentences.
    :param limit: How many lines to read from path. None reads every line.
    :return: Iterator over lists of words.
    """
    unlimited = limit is None
    with smart_open(path) as lines:
        for line_index, line in enumerate(lines):
            # line_index counts the lines already yielded, so stop once it reaches the limit.
            if not unlimited and line_index == limit:
                break
            yield list(get_tokens(line))


def tokens2ids(tokens: Iterable[str], vocab: Dict[str, int]) -> List[int]:
    """
    Maps every token to its id in vocab; tokens missing from vocab map to the id of the UNK symbol.

    :param tokens: List of string tokens.
    :param vocab: Vocabulary (containing UNK symbol).
    :return: List of word ids.
    """
    def to_id(token: str) -> int:
        # The UNK id is looked up for every token, also for tokens that are in the vocabulary.
        return vocab.get(token, vocab[C.UNK_SYMBOL])

    return list(map(to_id, tokens))


def strids2ids(tokens: Iterable[str]) -> List[int]:
    """
    Converts a sequence of ids encoded as strings into a list of integers.

    :param tokens: List of integer tokens (as strings).
    :return: List of word ids.
    """
    return [int(string_id) for string_id in tokens]


def ids2tokens(token_ids: Iterable[int],
               vocab_inv: Dict[int, str],
               exclude_set: Set[int]) -> Iterator[str]:
    """
    Lazily transforms token IDs into words, dropping any IDs in `exclude_set`.

    The input is traversed exactly once, so one-shot iterators are handled correctly.
    Excluded IDs are never looked up and therefore do not need to be present in `vocab_inv`.

    :param token_ids: The token IDs.
    :param vocab_inv: The inverse vocabulary.
    :param exclude_set: The token IDs to exclude.
    :return: Iterator over the words of all IDs that are not excluded.
    :raises KeyError: While iterating, if a non-excluded ID is missing from `vocab_inv`.
    """
    # Filter before the lookup: a single pass keeps IDs and words aligned for any iterable.
    return (vocab_inv[token_id] for token_id in token_ids if token_id not in exclude_set)


class SequenceReader:
    """
    Streams sequence samples from a file and turns each line into a list of integer ids.
    Lines are read lazily from disk instead of being loaded into memory at once.
    With a vocabulary, tokens are mapped to ids through it; without one, the tokens in path
    are assumed to be integers coded as strings.
    Empty sequences are yielded as None.

    :param path: Path to read data from.
    :param vocabulary: Optional mapping from strings to integer ids.
    :param add_bos: Whether to add Beginning-Of-Sentence (BOS) symbol.
    :param add_eos: Whether to add End-Of-Sentence (EOS) symbol.
    :param limit: Read limit.
    """

    def __init__(self,
                 path: str,
                 vocabulary: Optional[vocab.Vocab] = None,
                 add_bos: bool = False,
                 add_eos: bool = False,
                 limit: Optional[int] = None) -> None:
        self.path = path
        self.vocab = vocabulary
        if vocabulary is None:
            # Without a vocabulary there are no ids for the BOS/EOS symbols.
            self.bos_id = None
            self.eos_id = None
            check_condition(not (add_bos or add_eos), "Adding a BOS or EOS symbol requires a vocabulary")
        else:
            self.bos_id = None
            self.eos_id = None
            assert vocab.is_valid_vocab(vocabulary)
            self.bos_id = C.BOS_ID
            self.eos_id = C.EOS_ID
        self.add_bos = add_bos
        self.add_eos = add_eos
        self.limit = limit

    def __iter__(self):
        for line_tokens in read_content(self.path, self.limit):
            if self.vocab is None:
                ids = strids2ids(line_tokens)
            else:
                ids = tokens2ids(line_tokens, self.vocab)

            if ids:
                if self.add_bos:
                    ids.insert(0, self.bos_id)
                if self.add_eos:
                    ids.append(self.eos_id)
                yield ids
            else:
                # Empty lines are signalled as None and never receive BOS/EOS symbols.
                yield None


def create_sequence_readers(sources: List[str], targets: List[str],
                            vocab_sources: List[vocab.Vocab],
                            vocab_targets: List[vocab.Vocab]) -> Tuple[List[SequenceReader], List[SequenceReader]]:
    """
    Creates one reader per file: source readers append EOS, target readers prepend BOS.
    Files and vocabularies are paired by position.

    :param sources: The file names of source data and factors.
    :param targets: The file names of the target data and factors.
    :param vocab_sources: The source vocabularies.
    :param vocab_targets: The target vocabularies.
    :return: The source sequence readers and the target sequence readers.
    """
    readers_for_sources = [SequenceReader(path, vocabulary, add_eos=True)
                           for path, vocabulary in zip(sources, vocab_sources)]
    readers_for_targets = [SequenceReader(path, vocabulary, add_bos=True)
                           for path, vocabulary in zip(targets, vocab_targets)]
    return readers_for_sources, readers_for_targets


def parallel_iter(source_iterables: Sequence[Iterable[Optional[Any]]],
                  target_iterables: Sequence[Iterable[Optional[Any]]],
                  skip_blanks: bool = True,
                  check_token_parallel: bool = True):
    """
    Turns every iterable into an iterator and delegates to parallel_iterate().
    The two steps are kept separate so that callers who want to keep iterator state
    between calls can call parallel_iterate() directly.

    :param source_iterables: A list of source iterables.
    :param target_iterables: A list of target iterables.
    :param skip_blanks: Whether to skip empty target lines.
    :param check_token_parallel: Whether to check if the tokens are parallel or not.
    :return: Iterators over sources and target.
    """
    started_sources = list(map(iter, source_iterables))
    started_targets = list(map(iter, target_iterables))
    return parallel_iterate(started_sources,
                            started_targets,
                            skip_blanks,
                            check_token_parallel)


def parallel_iterate(source_iterators: Sequence[Iterator[Optional[Any]]],
                     target_iterators: Sequence[Iterator[Optional[Any]]],
                     skip_blanks: bool = True,
                     check_token_parallel: bool = True):
    """
    Yields parallel source(s), target(s) sequences from iterators.
    Optionally checks for token parallelism within the source and within the target sequences.
    Optionally skips pairs where the element of at least one iterator is None.
    Checks that all iterators have the same number of elements.
    Can optionally continue from an already-begun iterator.

    :param source_iterators: A list of source iterators.
    :param target_iterators: A list of target iterators.
    :param skip_blanks: Whether to skip pairs in which any source or target element is None.
    :param check_token_parallel: Whether to check if the tokens are parallel or not.
    :return: Iterators over sources and target.
    """
    # Unique end marker. None cannot signal exhaustion because None is a regular element
    # (an empty line), so a trailing blank line in a longer input would go unnoticed.
    exhausted = object()
    num_skipped = 0
    same_length = True
    while True:
        # Advance every iterator in the same round so that no element is consumed and then
        # silently dropped when another iterator turns out to be exhausted.
        sources = [next(source_iter, exhausted) for source_iter in source_iterators]
        targets = [next(target_iter, exhausted) for target_iter in target_iterators]
        ended = [element is exhausted for element in sources + targets]
        if any(ended):
            # The inputs have equal length only if all of them ran out in this very round.
            same_length = all(ended)
            break
        if skip_blanks and (any((s is None for s in sources)) or any((t is None for t in targets))):
            num_skipped += 1
            continue
        if check_token_parallel:
            check_condition(are_none(sources) or are_token_parallel(sources),
                            "Source sequences are not token-parallel: %s" % (str(sources)))
            check_condition(are_none(targets) or are_token_parallel(targets),
                            "Target sequences are not token-parallel: %s" % (str(targets)))
        yield sources, targets

    if num_skipped > 0:
        logger.warning("Parallel reading of sequences skipped %d elements", num_skipped)

    check_condition(same_length, "Different number of lines in source(s) and target(s) iterables.")
- - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - -



