in sockeye/data_io.py [0:0]
def get_training_data_iters(sources: List[str],
targets: List[str],
validation_sources: List[str],
validation_targets: List[str],
source_vocabs: List[vocab.Vocab],
target_vocabs: List[vocab.Vocab],
source_vocab_paths: List[Optional[str]],
target_vocab_paths: List[Optional[str]],
shared_vocab: bool,
batch_size: int,
batch_type: str,
max_seq_len_source: int,
max_seq_len_target: int,
bucketing: bool,
bucket_width: int,
bucket_scaling: bool = True,
allow_empty: bool = False,
batch_sentences_multiple_of: int = 1,
batch_num_devices: int = 1,
permute: bool = True) -> Tuple['BaseParallelSampleIter',
Optional['BaseParallelSampleIter'],
'DataConfig', 'DataInfo']:
"""
    Returns data iterators for training and validation data, along with the data config and data info.

:param sources: Path to source training data (with optional factor data paths).
:param targets: Path to target training data (with optional factor data paths).
:param validation_sources: Path to source validation data (with optional factor data paths).
:param validation_targets: Path to target validation data (with optional factor data paths).
:param source_vocabs: Source vocabulary and optional factor vocabularies.
:param target_vocabs: Target vocabulary and optional factor vocabularies.
    :param source_vocab_paths: Paths to source vocabularies.
    :param target_vocab_paths: Paths to target vocabularies.
:param shared_vocab: Whether the vocabularies are shared.
:param batch_size: Batch size.
    :param batch_type: Method for sizing batches.
    :param max_seq_len_source: Maximum source sequence length.
    :param max_seq_len_target: Maximum target sequence length.
    :param bucketing: Whether to use bucketing.
    :param bucket_width: Size of buckets.
    :param bucket_scaling: Scale bucket steps based on source/target length ratio.
    :param allow_empty: Unless True, raise an exception if no sentences are below or equal to the maximum length.
    :param batch_sentences_multiple_of: Round the number of sentences in each
        bucket's batch to a multiple of this value (word-based batching only).
    :param batch_num_devices: Number of devices batches will be parallelized across.
    :param permute: Randomly shuffle the parallel data.
    :return: Tuple of (training data iterator, validation data iterator, data config, data info).
"""
logger.info("===============================")
logger.info("Creating training data iterator")
logger.info("===============================")
# Pass 1: get target/source length ratios.
length_statistics = analyze_sequence_lengths(sources, targets, source_vocabs, target_vocabs,
max_seq_len_source, max_seq_len_target)
    if not allow_empty:
        check_condition(length_statistics.num_sents > 0,
                        "No training sequences found with length smaller than or equal to the maximum "
                        "sequence length. Consider increasing %s" % C.TRAINING_ARG_MAX_SEQ_LEN)
    # Define buckets (a single bucket spanning the maximum lengths if bucketing is disabled).
buckets = define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width, bucket_scaling,
length_statistics.length_ratio_mean) if bucketing else [(max_seq_len_source,
max_seq_len_target)]
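    # Rough illustration (assumed shapes, not from the source): with bucket_width=8 and a length
    # ratio near 1.0, buckets look like [(8, 8), (16, 16), ..., (max_seq_len_source, max_seq_len_target)].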
sources_sentences, targets_sentences = create_sequence_readers(sources, targets, source_vocabs, target_vocabs)
# Pass 2: Get data statistics and determine the number of data points for each bucket.
data_statistics = get_data_statistics(sources_sentences, targets_sentences, buckets,
length_statistics.length_ratio_mean, length_statistics.length_ratio_std,
source_vocabs, target_vocabs)
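    # With word-based batching, the number of sentences per batch varies by bucket and is
    # derived from each bucket's average target length.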
bucket_batch_sizes = define_bucket_batch_sizes(buckets,
batch_size,
batch_type,
data_statistics.average_len_target_per_bucket,
batch_sentences_multiple_of,
batch_num_devices)
data_statistics.log(bucket_batch_sizes)
# Pass 3: Load the data into memory and return the iterator.
data_loader = RawParallelDatasetLoader(buckets=buckets,
eos_id=C.EOS_ID,
pad_id=C.PAD_ID)
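    # fill_up upsamples each bucket so its size is a multiple of the bucket's batch size,
    # avoiding partial final batches.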
training_data = data_loader.load(sources_sentences, targets_sentences,
data_statistics.num_sents_per_bucket).fill_up(bucket_batch_sizes)
data_info = DataInfo(sources=sources,
targets=targets,
source_vocabs=source_vocab_paths,
target_vocabs=target_vocab_paths,
shared_vocab=shared_vocab,
num_shards=1)
config_data = DataConfig(data_statistics=data_statistics,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
num_source_factors=len(sources),
num_target_factors=len(targets))
train_iter = ParallelSampleIter(data=training_data,
buckets=buckets,
batch_size=batch_size,
bucket_batch_sizes=bucket_batch_sizes,
num_source_factors=len(sources),
num_target_factors=len(targets),
permute=permute)
validation_iter = get_validation_data_iter(data_loader=data_loader,
validation_sources=validation_sources,
validation_targets=validation_targets,
buckets=buckets,
bucket_batch_sizes=bucket_batch_sizes,
source_vocabs=source_vocabs,
target_vocabs=target_vocabs,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
batch_size=batch_size,
permute=permute)
return train_iter, validation_iter, config_data, data_info
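
A minimal usage sketch, not taken from the repository: it assumes a single source/target stream without factors, hypothetical file paths ("train.src" etc.), and vocabularies built up front with sockeye's vocab module (build_from_paths is assumed to accept a list of paths).

# Hypothetical paths and settings; vocabulary-building API is an assumption.
source_vocab = vocab.build_from_paths(["train.src"])
target_vocab = vocab.build_from_paths(["train.trg"])

train_iter, val_iter, config_data, data_info = get_training_data_iters(
    sources=["train.src"],
    targets=["train.trg"],
    validation_sources=["dev.src"],
    validation_targets=["dev.trg"],
    source_vocabs=[source_vocab],
    target_vocabs=[target_vocab],
    source_vocab_paths=[None],
    target_vocab_paths=[None],
    shared_vocab=False,
    batch_size=4096,
    batch_type="word",        # assumed to correspond to C.BATCH_TYPE_WORD
    max_seq_len_source=95,
    max_seq_len_target=95,
    bucketing=True,
    bucket_width=8)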