in sockeye/data_io_pt.py [0:0]
def prepare_data(source_fnames: List[str],
target_fnames: List[str],
source_vocabs: List[vocab.Vocab],
target_vocabs: List[vocab.Vocab],
source_vocab_paths: List[Optional[str]],
target_vocab_paths: List[Optional[str]],
shared_vocab: bool,
max_seq_len_source: int,
max_seq_len_target: int,
bucketing: bool,
bucket_width: int,
num_shards: int,
output_prefix: str,
bucket_scaling: bool = True,
keep_tmp_shard_files: bool = False,
pool: multiprocessing.pool.Pool = None,
shards: List[Tuple[Tuple[str, ...], Tuple[str, ...]]] = None):
"""
:param shards: List of num_shards shards of parallel source and target tuples which in turn contain tuples to shard data factor file paths.
"""
logger.info("Preparing data.")
# write vocabularies to data folder
vocab.save_source_vocabs(source_vocabs, output_prefix)
vocab.save_target_vocabs(target_vocabs, output_prefix)
# Get target/source length ratios.
stats_args = ((source_path, target_path, source_vocabs, target_vocabs, max_seq_len_source, max_seq_len_target)
for source_path, target_path in shards)
length_stats = pool.starmap(analyze_sequence_lengths, stats_args)
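# Combine the per-shard length-ratio statistics into corpus-level mean and std, weighting each shard by its sentence count.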
shards_num_sents = [stat.num_sents for stat in length_stats]
shards_mean = [stat.length_ratio_mean for stat in length_stats]
shards_std = [stat.length_ratio_std for stat in length_stats]
length_ratio_mean = combine_means(shards_mean, shards_num_sents)
length_ratio_std = combine_stds(shards_std, shards_mean, shards_num_sents)
length_statistics = LengthStatistics(sum(shards_num_sents), length_ratio_mean, length_ratio_std)
check_condition(length_statistics.num_sents > 0,
"No training sequences found with length smaller than or equal to the maximum sequence length. "
"Consider increasing %s" % C.TRAINING_ARG_MAX_SEQ_LEN)
# define buckets
buckets = define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width, bucket_scaling,
length_statistics.length_ratio_mean) if bucketing else [(max_seq_len_source, max_seq_len_target)]
logger.info("Buckets: %s", buckets)
# Map sentences to ids, assign to buckets, compute shard statistics and convert each shard to serialized tensors
data_loader = RawParallelDatasetLoader(buckets=buckets,
eos_id=C.EOS_ID,
pad_id=C.PAD_ID)
# Process shards in parallel
args = ((shard_idx, data_loader, shard_sources, shard_targets, source_vocabs, target_vocabs,
length_statistics.length_ratio_mean, length_statistics.length_ratio_std, buckets, output_prefix,
keep_tmp_shard_files) for shard_idx, (shard_sources, shard_targets) in enumerate(shards))
per_shard_statistics = pool.starmap(save_shard, args)
# Combine per shard statistics to obtain global statistics
shard_average_len = [shard_stats.average_len_target_per_bucket for shard_stats in per_shard_statistics]
shard_num_sents = [shard_stats.num_sents_per_bucket for shard_stats in per_shard_statistics]
num_sents_per_bucket = [sum(n) for n in zip(*shard_num_sents)]
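# Average target length per bucket: combine shard means weighted by the per-bucket sentence counts;
# buckets that are empty in every shard stay None.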
average_len_target_per_bucket = [] # type: List[Optional[float]]
for num_sents_bucket, average_len_bucket in zip(zip(*shard_num_sents), zip(*shard_average_len)):
if all(avg is None for avg in average_len_bucket):
average_len_target_per_bucket.append(None)
else:
average_len_target_per_bucket.append(combine_means(average_len_bucket, num_sents_bucket))
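# Pool the per-bucket length-ratio mean/std across shards in the same way.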
shard_length_ratios = [shard_stats.length_ratio_stats_per_bucket for shard_stats in per_shard_statistics]
length_ratio_stats_per_bucket = [] # type: List[Tuple[Optional[float], Optional[float]]]
for num_sents_bucket, len_ratios_bucket in zip(zip(*shard_num_sents), zip(*shard_length_ratios)):
if all(all(x is None for x in ratio) for ratio in len_ratios_bucket):
length_ratio_stats_per_bucket.append((None, None))
else:
shards_mean = [ratio[0] for ratio in len_ratios_bucket]
ratio_mean = combine_means(shards_mean, num_sents_bucket)
ratio_std = combine_stds([ratio[1] for ratio in len_ratios_bucket], shards_mean, num_sents_bucket)
length_ratio_stats_per_bucket.append((ratio_mean, ratio_std))
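# Assemble corpus-level data statistics from the aggregated per-shard values.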
data_statistics = DataStatistics(
num_sents=sum(shards_num_sents),
num_discarded=sum(shard_stats.num_discarded for shard_stats in per_shard_statistics),
num_tokens_source=sum(shard_stats.num_tokens_source for shard_stats in per_shard_statistics),
num_tokens_target=sum(shard_stats.num_tokens_target for shard_stats in per_shard_statistics),
num_unks_source=sum(shard_stats.num_unks_source for shard_stats in per_shard_statistics),
num_unks_target=sum(shard_stats.num_unks_target for shard_stats in per_shard_statistics),
max_observed_len_source=max(shard_stats.max_observed_len_source for shard_stats in per_shard_statistics),
max_observed_len_target=max(shard_stats.max_observed_len_target for shard_stats in per_shard_statistics),
size_vocab_source=per_shard_statistics[0].size_vocab_source,
size_vocab_target=per_shard_statistics[0].size_vocab_target,
length_ratio_mean=length_ratio_mean,
length_ratio_std=length_ratio_std,
buckets=per_shard_statistics[0].buckets,
num_sents_per_bucket=num_sents_per_bucket,
average_len_target_per_bucket=average_len_target_per_bucket,
length_ratio_stats_per_bucket=length_ratio_stats_per_bucket)
data_statistics.log()
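# Record which input files and vocabularies the prepared data was built from.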
data_info = DataInfo(sources=[os.path.abspath(fname) for fname in source_fnames],
targets=[os.path.abspath(fname) for fname in target_fnames],
source_vocabs=source_vocab_paths,
target_vocabs=target_vocab_paths,
shared_vocab=shared_vocab,
num_shards=num_shards)
data_info_fname = os.path.join(output_prefix, C.DATA_INFO)
logger.info("Writing data info to '%s'", data_info_fname)
data_info.save(data_info_fname)
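# Persist the data configuration (statistics, maximum sequence lengths, number of source/target factors).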
config_data = DataConfig(data_statistics=data_statistics,
max_seq_len_source=max_seq_len_source,
max_seq_len_target=max_seq_len_target,
num_source_factors=len(source_fnames),
num_target_factors=len(target_fnames))
config_data_fname = os.path.join(output_prefix, C.DATA_CONFIG)
logger.info("Writing data config to '%s'", config_data_fname)
config_data.save(config_data_fname)
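# Write the prepared data format version so that loading code can check compatibility.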
version_file = os.path.join(output_prefix, C.PREPARED_DATA_VERSION_FILE)
with open(version_file, "w") as version_out:
version_out.write(str(C.PREPARED_DATA_VERSION))
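# Minimal usage sketch (illustrative only; `src_vocab`, `trg_vocab`, `shard_files` and
# `output_dir` are placeholder names built elsewhere, e.g. by vocabulary creation and a
# prior sharding step -- they are not defined in this excerpt):
#
#     shard_files = [(("shard.0.src",), ("shard.0.trg",))]  # one shard, one factor per side
#     with multiprocessing.Pool(processes=4) as pool:
#         prepare_data(source_fnames=["train.src"], target_fnames=["train.trg"],
#                      source_vocabs=[src_vocab], target_vocabs=[trg_vocab],
#                      source_vocab_paths=[None], target_vocab_paths=[None],
#                      shared_vocab=False,
#                      max_seq_len_source=95, max_seq_len_target=95,
#                      bucketing=True, bucket_width=8,
#                      num_shards=len(shard_files), output_prefix=output_dir,
#                      pool=pool, shards=shard_files)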