def get_training_data_iters()

in sockeye/data_io.py [0:0]


def get_training_data_iters(sources: List[str],
                            targets: List[str],
                            validation_sources: List[str],
                            validation_targets: List[str],
                            source_vocabs: List[vocab.Vocab],
                            target_vocabs: List[vocab.Vocab],
                            source_vocab_paths: List[Optional[str]],
                            target_vocab_paths: List[Optional[str]],
                            shared_vocab: bool,
                            batch_size: int,
                            batch_type: str,
                            max_seq_len_source: int,
                            max_seq_len_target: int,
                            bucketing: bool,
                            bucket_width: int,
                            bucket_scaling: bool = True,
                            allow_empty: bool = False,
                            batch_sentences_multiple_of: int = 1,
                            batch_num_devices: int = 1,
                            permute: bool = True) -> Tuple['BaseParallelSampleIter',
                                                           Optional['BaseParallelSampleIter'],
                                                           'DataConfig', 'DataInfo']:
    """
    Returns data iterators for training and validation data.

    The training data is processed in three passes:
      1. sequence length statistics are computed to obtain the target/source length ratio,
      2. per-bucket data statistics are computed and used to size each bucket's batch,
      3. the data is loaded into memory and wrapped in a ParallelSampleIter.
    The validation iterator is then built with the same loader, buckets and bucket batch sizes.

    :param sources: Path to source training data (with optional factor data paths).
    :param targets: Path to target training data (with optional factor data paths).
    :param validation_sources: Path to source validation data (with optional factor data paths).
    :param validation_targets: Path to target validation data (with optional factor data paths).
    :param source_vocabs: Source vocabulary and optional factor vocabularies.
    :param target_vocabs: Target vocabulary and optional factor vocabularies.
    :param source_vocab_paths: Path to source vocabularies.
    :param target_vocab_paths: Path to target vocabularies.
    :param shared_vocab: Whether the vocabularies are shared.
    :param batch_size: Batch size.
    :param batch_type: Method for sizing batches.
    :param max_seq_len_source: Maximum source sequence length.
    :param max_seq_len_target: Maximum target sequence length.
    :param bucketing: Whether to use bucketing. If False, a single bucket of
        (max_seq_len_source, max_seq_len_target) is used.
    :param bucket_width: Size of buckets.
    :param bucket_scaling: Scale bucket steps based on source/target length ratio.
    :param allow_empty: If False (the default), fail when no training sentence pair is within
        the maximum sequence lengths. If True, this check is skipped.
    :param batch_sentences_multiple_of: Round the number of sentences in each
        bucket's batch to a multiple of this value (word-based batching only).
    :param batch_num_devices: Number of devices batches will be parallelized across.
    :param permute: Randomly shuffle the parallel data.

    :return: Tuple of (training data iterator, validation data iterator, data config, data info).
        The validation iterator is typed as Optional.
    :raises: Whatever ``check_condition`` signals on failure (presumably an exception) when
        ``allow_empty`` is False and no usable training sequences are found.
    """
    logger.info("===============================")
    logger.info("Creating training data iterator")
    logger.info("===============================")
    # Pass 1: get target/source length ratios.
    length_statistics = analyze_sequence_lengths(sources, targets, source_vocabs, target_vocabs,
                                                 max_seq_len_source, max_seq_len_target)

    if not allow_empty:
        # Note the trailing space after "length." -- the two literals are concatenated.
        check_condition(length_statistics.num_sents > 0,
                        "No training sequences found with length smaller or equal than the maximum sequence length. "
                        "Consider increasing %s" % C.TRAINING_ARG_MAX_SEQ_LEN)

    # Define buckets. Without bucketing, everything goes into one bucket sized to the maxima.
    if bucketing:
        buckets = define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width, bucket_scaling,
                                          length_statistics.length_ratio_mean)
    else:
        buckets = [(max_seq_len_source, max_seq_len_target)]

    # The same readers are consumed in pass 2 and pass 3; presumably they are re-iterable
    # (TODO confirm against create_sequence_readers).
    sources_sentences, targets_sentences = create_sequence_readers(sources, targets, source_vocabs, target_vocabs)

    # Pass 2: Get data statistics and determine the number of data points for each bucket.
    data_statistics = get_data_statistics(sources_sentences, targets_sentences, buckets,
                                          length_statistics.length_ratio_mean, length_statistics.length_ratio_std,
                                          source_vocabs, target_vocabs)

    bucket_batch_sizes = define_bucket_batch_sizes(buckets,
                                                   batch_size,
                                                   batch_type,
                                                   data_statistics.average_len_target_per_bucket,
                                                   batch_sentences_multiple_of,
                                                   batch_num_devices)

    data_statistics.log(bucket_batch_sizes)

    # Pass 3: Load the data into memory and return the iterator.
    data_loader = RawParallelDatasetLoader(buckets=buckets,
                                           eos_id=C.EOS_ID,
                                           pad_id=C.PAD_ID)

    training_data = data_loader.load(sources_sentences, targets_sentences,
                                     data_statistics.num_sents_per_bucket).fill_up(bucket_batch_sizes)

    data_info = DataInfo(sources=sources,
                         targets=targets,
                         source_vocabs=source_vocab_paths,
                         target_vocabs=target_vocab_paths,
                         shared_vocab=shared_vocab,
                         num_shards=1)

    # sources/targets hold one path per factor (primary stream included), so their
    # lengths give the number of source/target factors.
    config_data = DataConfig(data_statistics=data_statistics,
                             max_seq_len_source=max_seq_len_source,
                             max_seq_len_target=max_seq_len_target,
                             num_source_factors=len(sources),
                             num_target_factors=len(targets))

    train_iter = ParallelSampleIter(data=training_data,
                                    buckets=buckets,
                                    batch_size=batch_size,
                                    bucket_batch_sizes=bucket_batch_sizes,
                                    num_source_factors=len(sources),
                                    num_target_factors=len(targets),
                                    permute=permute)

    # Validation data reuses the training buckets and bucket batch sizes.
    validation_iter = get_validation_data_iter(data_loader=data_loader,
                                               validation_sources=validation_sources,
                                               validation_targets=validation_targets,
                                               buckets=buckets,
                                               bucket_batch_sizes=bucket_batch_sizes,
                                               source_vocabs=source_vocabs,
                                               target_vocabs=target_vocabs,
                                               max_seq_len_source=max_seq_len_source,
                                               max_seq_len_target=max_seq_len_target,
                                               batch_size=batch_size,
                                               permute=permute)

    return train_iter, validation_iter, config_data, data_info