in sockeye/train.py
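# Module-level imports assumed from the surrounding file (not shown in this excerpt);
# the exact import forms are an assumption based on the names used below:
#   import argparse
#   import os
#   from typing import List, Tuple, cast
#   from . import constants as C, data_io, utils, vocab
#   from .config import Config
#   from .utils import check_condition
# `logger` is assumed to be a module-level logging.Logger.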
def create_data_iters_and_vocabs(args: argparse.Namespace,
                                 max_seq_len_source: int,
                                 max_seq_len_target: int,
                                 shared_vocab: bool,
                                 resume_training: bool,
                                 output_folder: str) -> Tuple['data_io.BaseParallelSampleIter',
                                                              'data_io.BaseParallelSampleIter',
                                                              'data_io.DataConfig',
                                                              List[vocab.Vocab], List[vocab.Vocab]]:
"""
Create the data iterators and the vocabularies.
:param args: Arguments as returned by argparse.
:param max_seq_len_source: Source maximum sequence length.
:param max_seq_len_target: Target maximum sequence length.
:param shared_vocab: Whether to create a shared vocabulary.
:param resume_training: Whether to resume training.
:param output_folder: Output folder.
:return: The data iterators (train, validation, config_data) as well as the source and target vocabularies.
"""
    num_words_source, num_words_target = args.num_words
    num_words_source = num_words_source if num_words_source > 0 else None
    num_words_target = num_words_target if num_words_target > 0 else None
    word_min_count_source, word_min_count_target = args.word_min_count
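    # Count devices: a negative device id is sockeye's shorthand for "automatically acquire
    # that many GPUs", so it contributes its magnitude; explicit ids each count as 1.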
    batch_num_devices = 1 if args.use_cpu else sum(-di if di < 0 else 1 for di in args.device_ids)
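    # Validation paths include any factor files and are resolved to absolute paths.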
    validation_sources = [args.validation_source] + args.validation_source_factors
    validation_sources = [str(os.path.abspath(source)) for source in validation_sources]
    validation_targets = [args.validation_target] + args.validation_target_factors
    validation_targets = [str(os.path.abspath(target)) for target in validation_targets]
    if args.horovod:
        horovod_data_error_msg = "Horovod training requires prepared training data. Use `python -m " \
                                 "sockeye.prepare_data` and specify with %s" % C.TRAINING_ARG_PREPARED_DATA
        check_condition(args.prepared_data is not None, horovod_data_error_msg)
    either_raw_or_prepared_error_msg = "Either specify a raw training corpus with %s and %s or a preprocessed corpus " \
                                       "with %s." % (C.TRAINING_ARG_SOURCE,
                                                     C.TRAINING_ARG_TARGET,
                                                     C.TRAINING_ARG_PREPARED_DATA)
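    # Branch 1: train from a prepared data folder (the output of sockeye.prepare_data).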
    if args.prepared_data is not None:
        utils.check_condition(args.source is None and args.target is None, either_raw_or_prepared_error_msg)
        if not resume_training:
            utils.check_condition(args.source_vocab is None and args.target_vocab is None,
                                  "You are using a prepared data folder, which is tied to a vocabulary. "
                                  "To change it you need to rerun data preparation with a different vocabulary.")
        train_iter, validation_iter, data_config, source_vocabs, target_vocabs = data_io.get_prepared_data_iters(
            prepared_data_dir=args.prepared_data,
            validation_sources=validation_sources,
            validation_targets=validation_targets,
            shared_vocab=shared_vocab,
            batch_size=args.batch_size,
            batch_type=args.batch_type,
            batch_num_devices=batch_num_devices,
            batch_sentences_multiple_of=args.batch_sentences_multiple_of)
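        # When factors are combined by summing or averaging they reuse the main embedding size,
        # so explicit per-factor dimensions are only required for concatenation.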
        check_condition(all([combine in [C.FACTORS_COMBINE_SUM, C.FACTORS_COMBINE_AVERAGE]
                             for combine in args.source_factors_combine])
                        or len(source_vocabs) == len(args.source_factors_num_embed) + 1,
                        "Data was prepared with %d source factors, but only %d source factor dimensions were "
                        "provided." % (len(source_vocabs) - 1, len(args.source_factors_num_embed)))
        check_condition(all([combine in [C.FACTORS_COMBINE_SUM, C.FACTORS_COMBINE_AVERAGE]
                             for combine in args.target_factors_combine])
                        or len(target_vocabs) == len(args.target_factors_num_embed) + 1,
                        "Data was prepared with %d target factors, but only %d target factor dimensions were "
                        "provided." % (len(target_vocabs) - 1, len(args.target_factors_num_embed)))
        if resume_training:
            # Resuming training: make sure the vocabularies in the model and in the prepared data match up.
            model_source_vocabs = vocab.load_source_vocabs(output_folder)
            for i, (v, mv) in enumerate(zip(source_vocabs, model_source_vocabs)):
                utils.check_condition(vocab.are_identical(v, mv),
                                      "Prepared data and resumed model source vocab %d do not match." % i)
            model_target_vocabs = vocab.load_target_vocabs(output_folder)
            for i, (v, mv) in enumerate(zip(target_vocabs, model_target_vocabs)):
                utils.check_condition(vocab.are_identical(v, mv),
                                      "Prepared data and resumed model target vocab %d do not match." % i)
        check_condition(data_config.num_source_factors == len(validation_sources),
                        'Training and validation data must have the same number of source factors,'
                        ' but found %d and %d.' % (data_config.num_source_factors, len(validation_sources)))
        check_condition(data_config.num_target_factors == len(validation_targets),
                        'Training and validation data must have the same number of target factors,'
                        ' but found %d and %d.' % (data_config.num_target_factors, len(validation_targets)))
        return train_iter, validation_iter, data_config, source_vocabs, target_vocabs
    else:
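        # Branch 2: train directly from raw (non-prepared) source/target files.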
        utils.check_condition(args.prepared_data is None and args.source is not None and args.target is not None,
                              either_raw_or_prepared_error_msg)
        if resume_training:
            # Load the existing vocabs created when starting the training run.
            source_vocabs = vocab.load_source_vocabs(output_folder)
            target_vocabs = vocab.load_target_vocabs(output_folder)
            # Recover the vocabulary paths from the data info file:
            data_info = cast(data_io.DataInfo, Config.load(os.path.join(output_folder, C.DATA_INFO)))
            source_vocab_paths = data_info.source_vocabs
            target_vocab_paths = data_info.target_vocabs
        else:
            # Load or create vocabs
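            # Pad the factor vocabulary path lists with None so that every factor stream has an
            # entry; a None path means the vocabulary is built from the training data.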
            source_factor_vocab_paths = [args.source_factor_vocabs[i] if i < len(args.source_factor_vocabs)
                                         else None for i in range(len(args.source_factors))]
            source_vocab_paths = [args.source_vocab] + source_factor_vocab_paths
            target_factor_vocab_paths = [args.target_factor_vocabs[i] if i < len(args.target_factor_vocabs)
                                         else None for i in range(len(args.target_factors))]
            target_vocab_paths = [args.target_vocab] + target_factor_vocab_paths
            source_vocabs, target_vocabs = vocab.load_or_create_vocabs(
                shard_source_paths=[[args.source] + args.source_factors],
                shard_target_paths=[[args.target] + args.target_factors],
                source_vocab_paths=source_vocab_paths,
                source_factor_vocab_same_as_source=args.source_factors_share_embedding,
                target_vocab_paths=target_vocab_paths,
                target_factor_vocab_same_as_target=args.target_factors_share_embedding,
                shared_vocab=shared_vocab,
                num_words_source=num_words_source,
                num_words_target=num_words_target,
                word_min_count_source=word_min_count_source,
                word_min_count_target=word_min_count_target,
                pad_to_multiple_of=args.pad_vocab_to_multiple_of)
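        # As in the prepared-data branch: per-factor embedding dimensions are only required
        # when factors are concatenated rather than summed or averaged.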
        check_condition(all([combine in [C.FACTORS_COMBINE_SUM, C.FACTORS_COMBINE_AVERAGE]
                             for combine in args.source_factors_combine])
                        or len(args.source_factors) == len(args.source_factors_num_embed),
                        "Number of source factor data (%d) differs from provided source factor dimensions (%d)" % (
                            len(args.source_factors), len(args.source_factors_num_embed)))
        check_condition(all([combine in [C.FACTORS_COMBINE_SUM, C.FACTORS_COMBINE_AVERAGE]
                             for combine in args.target_factors_combine])
                        or len(args.target_factors) == len(args.target_factors_num_embed),
                        "Number of target factor data (%d) differs from provided target factor dimensions (%d)" % (
                            len(args.target_factors), len(args.target_factors_num_embed)))
        sources = [args.source] + args.source_factors
        sources = [str(os.path.abspath(s)) for s in sources]
        targets = [args.target] + args.target_factors
        targets = [str(os.path.abspath(t)) for t in targets]
        check_condition(len(sources) == len(validation_sources),
                        'Training and validation data must have the same number of source factors, '
                        'but found %d and %d.' % (len(sources), len(validation_sources)))
        check_condition(len(targets) == len(validation_targets),
                        'Training and validation data must have the same number of target factors, '
                        'but found %d and %d.' % (len(targets), len(validation_targets)))
        train_iter, validation_iter, config_data, data_info = data_io.get_training_data_iters(
            sources=sources,
            targets=targets,
            validation_sources=validation_sources,
            validation_targets=validation_targets,
            source_vocabs=source_vocabs,
            target_vocabs=target_vocabs,
            source_vocab_paths=source_vocab_paths,
            target_vocab_paths=target_vocab_paths,
            shared_vocab=shared_vocab,
            batch_size=args.batch_size,
            batch_type=args.batch_type,
            batch_num_devices=batch_num_devices,
            max_seq_len_source=max_seq_len_source,
            max_seq_len_target=max_seq_len_target,
            bucketing=not args.no_bucketing,
            bucket_width=args.bucket_width,
            bucket_scaling=args.bucket_scaling,
            batch_sentences_multiple_of=args.batch_sentences_multiple_of)
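        # Persist the data info so that a resumed run can recover the vocabulary paths
        # (see the resume_training branch above).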
        data_info_fname = os.path.join(output_folder, C.DATA_INFO)
        logger.info("Writing data config to '%s'", data_info_fname)
        data_info.save(data_info_fname)
        return train_iter, validation_iter, config_data, source_vocabs, target_vocabs