in sockeye/data_io.py [0:0]
def get_prepared_data_iters(prepared_data_dir: str,
                            validation_sources: List[str],
                            validation_targets: List[str],
                            shared_vocab: bool,
                            batch_size: int,
                            batch_type: str,
                            batch_sentences_multiple_of: int = 1,
                            batch_num_devices: int = 1,
                            permute: bool = True) -> Tuple['BaseParallelSampleIter',
                                                           'BaseParallelSampleIter',
                                                           'DataConfig', List[vocab.Vocab], List[vocab.Vocab]]:
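    """
    Loads data iterators from a directory created by sockeye.prepare_data: verifies the
    prepared data format version, loads the stored data info, data config, and vocabularies,
    and constructs a sharded training iterator plus a validation iterator over the given
    validation files. Returns the two iterators, the data config, and the vocabularies.
    """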
    logger.info("===============================")
    logger.info("Creating training data iterator")
    logger.info("===============================")
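    # The prepared data must have been written by a compatible version of sockeye.prepare_data.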
    version_file = os.path.join(prepared_data_dir, C.PREPARED_DATA_VERSION_FILE)
    with open(version_file) as version_in:
        version = int(version_in.read())
        check_condition(version == C.PREPARED_DATA_VERSION,
                        "The dataset %s was written in an old and incompatible format. Please rerun data "
                        "preparation with a current version of Sockeye." % prepared_data_dir)
    info_file = os.path.join(prepared_data_dir, C.DATA_INFO)
    check_condition(os.path.exists(info_file),
                    "Could not find data info %s. Are you sure %s is a directory created with "
                    "python -m sockeye.prepare_data?" % (info_file, prepared_data_dir))
    data_info = cast(DataInfo, DataInfo.load(info_file))
    config_file = os.path.join(prepared_data_dir, C.DATA_CONFIG)
    check_condition(os.path.exists(config_file),
                    "Could not find data config %s. Are you sure %s is a directory created with "
                    "python -m sockeye.prepare_data?" % (config_file, prepared_data_dir))
    config_data = cast(DataConfig, DataConfig.load(config_file))
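    # All shards referenced by the data info must exist on disk.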
    shard_fnames = [os.path.join(prepared_data_dir,
                                 C.SHARD_NAME % shard_idx) for shard_idx in range(data_info.num_shards)]
    for shard_fname in shard_fnames:
        check_condition(os.path.exists(shard_fname), "Shard %s does not exist." % shard_fname)
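    # The shared-vocabulary setting must match the one used during data preparation.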
    check_condition(shared_vocab == data_info.shared_vocab, "Shared vocabulary settings need to match those "
                                                            "of the prepared data (e.g. for weight tying). "
                                                            "Specify or omit %s consistently when training "
                                                            "and preparing the data." % C.VOCAB_ARG_SHARED_VOCAB)
    source_vocabs = vocab.load_source_vocabs(prepared_data_dir)
    target_vocabs = vocab.load_target_vocabs(prepared_data_dir)
    check_condition(len(source_vocabs) == len(data_info.sources),
                    "Wrong number of source vocabularies. Found %d, need %d." % (len(source_vocabs),
                                                                                 len(data_info.sources)))
    check_condition(len(target_vocabs) == len(data_info.targets),
                    "Wrong number of target vocabularies. Found %d, need %d." % (len(target_vocabs),
                                                                                 len(data_info.targets)))
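    # Bucketing and maximum sequence lengths were fixed at data preparation time and are read from the stored config.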
    buckets = config_data.data_statistics.buckets
    max_seq_len_source = config_data.max_seq_len_source
    max_seq_len_target = config_data.max_seq_len_target
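    # Determine per-bucket batch sizes from the requested batch size, batch type, and number of devices.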
    bucket_batch_sizes = define_bucket_batch_sizes(buckets,
                                                   batch_size,
                                                   batch_type,
                                                   config_data.data_statistics.average_len_target_per_bucket,
                                                   batch_sentences_multiple_of,
                                                   batch_num_devices)
    config_data.data_statistics.log(bucket_batch_sizes)
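    # Build the training iterator over the prepared shards; permute enables reshuffling when the iterator is reset.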
    train_iter = ShardedParallelSampleIter(shard_fnames,
                                           buckets,
                                           batch_size,
                                           bucket_batch_sizes,
                                           num_source_factors=len(data_info.sources),
                                           num_target_factors=len(data_info.targets),
                                           permute=permute)
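    # Validation data is not pre-sharded; it is read from the raw validation files and bucketed on the fly
    # using the same buckets and vocabularies as the training data.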
    data_loader = RawParallelDatasetLoader(buckets=buckets,
                                           eos_id=C.EOS_ID,
                                           pad_id=C.PAD_ID)
    validation_iter = get_validation_data_iter(data_loader=data_loader,
                                               validation_sources=validation_sources,
                                               validation_targets=validation_targets,
                                               buckets=buckets,
                                               bucket_batch_sizes=bucket_batch_sizes,
                                               source_vocabs=source_vocabs,
                                               target_vocabs=target_vocabs,
                                               max_seq_len_source=max_seq_len_source,
                                               max_seq_len_target=max_seq_len_target,
                                               batch_size=batch_size,
                                               permute=permute)
    return train_iter, validation_iter, config_data, source_vocabs, target_vocabs
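A minimal call-site sketch (not part of data_io.py), assuming a directory produced by
python -m sockeye.prepare_data and matching validation files; the paths, file names, and batch
settings below are illustrative assumptions, not Sockeye defaults:

# Hypothetical call site, for illustration only; paths and sizes are made up.
train_iter, val_iter, config_data, src_vocabs, trg_vocabs = get_prepared_data_iters(
    prepared_data_dir="prepared_data",      # hypothetical output directory of sockeye.prepare_data
    validation_sources=["dev.src"],         # one file per source factor (hypothetical names)
    validation_targets=["dev.trg"],         # one file per target factor (hypothetical names)
    shared_vocab=True,                      # must match the setting used when preparing the data
    batch_size=4096,
    batch_type=C.BATCH_TYPE_WORD,           # assumption: word-based batching constant from sockeye.constants
    batch_sentences_multiple_of=8,
    batch_num_devices=1)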