def prepare_data()

in sockeye/data_io_pt.py [0:0]


def prepare_data(source_fnames: List[str],
                 target_fnames: List[str],
                 source_vocabs: List[vocab.Vocab],
                 target_vocabs: List[vocab.Vocab],
                 source_vocab_paths: List[Optional[str]],
                 target_vocab_paths: List[Optional[str]],
                 shared_vocab: bool,
                 max_seq_len_source: int,
                 max_seq_len_target: int,
                 bucketing: bool,
                 bucket_width: int,
                 num_shards: int,
                 output_prefix: str,
                 bucket_scaling: bool = True,
                 keep_tmp_shard_files: bool = False,
                 pool: Optional[multiprocessing.pool.Pool] = None,
                 shards: Optional[List[Tuple[Tuple[str, ...], Tuple[str, ...]]]] = None):
    """
    :param shards: List of num_shards shards of parallel source and target tuples which in turn contain tuples to shard data factor file paths.
    """
    logger.info("Preparing data.")
    # write vocabularies to data folder
    vocab.save_source_vocabs(source_vocabs, output_prefix)
    vocab.save_target_vocabs(target_vocabs, output_prefix)

    # Get target/source length ratios.
    stats_args = ((source_path, target_path, source_vocabs, target_vocabs, max_seq_len_source, max_seq_len_target)
                  for source_path, target_path in shards)
    length_stats = pool.starmap(analyze_sequence_lengths, stats_args)
    shards_num_sents = [stat.num_sents for stat in length_stats]
    shards_mean = [stat.length_ratio_mean for stat in length_stats]
    shards_std = [stat.length_ratio_std for stat in length_stats]
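    # Pool the per-shard means and standard deviations into corpus-level length statistics,
    # weighting each shard by its number of sentences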
    length_ratio_mean = combine_means(shards_mean, shards_num_sents)
    length_ratio_std = combine_stds(shards_std, shards_mean, shards_num_sents)
    length_statistics = LengthStatistics(sum(shards_num_sents), length_ratio_mean, length_ratio_std)

    check_condition(length_statistics.num_sents > 0,
                    "No training sequences found with length smaller or equal than the maximum sequence length."
                    "Consider increasing %s" % C.TRAINING_ARG_MAX_SEQ_LEN)

    # define buckets
    buckets = define_parallel_buckets(max_seq_len_source, max_seq_len_target, bucket_width, bucket_scaling,
                                      length_statistics.length_ratio_mean) if bucketing else [(max_seq_len_source,
                                                                                               max_seq_len_target)]
    logger.info("Buckets: %s", buckets)

    # Map sentences to ids, assign to buckets, compute shard statistics and convert each shard to serialized tensors
    data_loader = RawParallelDatasetLoader(buckets=buckets,
                                           eos_id=C.EOS_ID,
                                           pad_id=C.PAD_ID)

    # Process shards in parallel
    args = ((shard_idx, data_loader, shard_sources, shard_targets, source_vocabs, target_vocabs,
             length_statistics.length_ratio_mean, length_statistics.length_ratio_std, buckets, output_prefix,
             keep_tmp_shard_files) for shard_idx, (shard_sources, shard_targets) in enumerate(shards))
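    # Argument order here must match the positional parameters of save_shard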
    per_shard_statistics = pool.starmap(save_shard, args)

    # Combine per shard statistics to obtain global statistics
    shard_average_len = [shard_stats.average_len_target_per_bucket for shard_stats in per_shard_statistics]
    shard_num_sents = [shard_stats.num_sents_per_bucket for shard_stats in per_shard_statistics]
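    # Total number of sentences per bucket, summed over all shards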
    num_sents_per_bucket = [sum(n) for n in zip(*shard_num_sents)]
    average_len_target_per_bucket = []  # type: List[Optional[float]]
    for num_sents_bucket, average_len_bucket in zip(zip(*shard_num_sents), zip(*shard_average_len)):
        if all(avg is None for avg in average_len_bucket):
            average_len_target_per_bucket.append(None)
        else:
            average_len_target_per_bucket.append(combine_means(average_len_bucket, num_sents_bucket))

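    # Pool the per-bucket length-ratio means and standard deviations across shards in the same way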
    shard_length_ratios = [shard_stats.length_ratio_stats_per_bucket for shard_stats in per_shard_statistics]
    length_ratio_stats_per_bucket = []  # type: List[Tuple[Optional[float], Optional[float]]]
    for num_sents_bucket, len_ratios_bucket in zip(zip(*shard_num_sents), zip(*shard_length_ratios)):
        if all(all(x is None for x in ratio) for ratio in len_ratios_bucket):
            length_ratio_stats_per_bucket.append((None, None))
        else:
            shards_mean = [ratio[0] for ratio in len_ratios_bucket]
            ratio_mean = combine_means(shards_mean, num_sents_bucket)
            ratio_std = combine_stds([ratio[1] for ratio in len_ratios_bucket], shards_mean, num_sents_bucket)
            length_ratio_stats_per_bucket.append((ratio_mean, ratio_std))
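    # Aggregate the global data statistics over all shards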
    data_statistics = DataStatistics(
        num_sents=sum(shards_num_sents),
        num_discarded=sum(shard_stats.num_discarded for shard_stats in per_shard_statistics),
        num_tokens_source=sum(shard_stats.num_tokens_source for shard_stats in per_shard_statistics),
        num_tokens_target=sum(shard_stats.num_tokens_target for shard_stats in per_shard_statistics),
        num_unks_source=sum(shard_stats.num_unks_source for shard_stats in per_shard_statistics),
        num_unks_target=sum(shard_stats.num_unks_target for shard_stats in per_shard_statistics),
        max_observed_len_source=max(shard_stats.max_observed_len_source for shard_stats in per_shard_statistics),
        max_observed_len_target=max(shard_stats.max_observed_len_target for shard_stats in per_shard_statistics),
        size_vocab_source=per_shard_statistics[0].size_vocab_source,
        size_vocab_target=per_shard_statistics[0].size_vocab_target,
        length_ratio_mean=length_ratio_mean,
        length_ratio_std=length_ratio_std,
        buckets=per_shard_statistics[0].buckets,
        num_sents_per_bucket=num_sents_per_bucket,
        average_len_target_per_bucket=average_len_target_per_bucket,
        length_ratio_stats_per_bucket=length_ratio_stats_per_bucket)
    data_statistics.log()

    data_info = DataInfo(sources=[os.path.abspath(fname) for fname in source_fnames],
                         targets=[os.path.abspath(fname) for fname in target_fnames],
                         source_vocabs=source_vocab_paths,
                         target_vocabs=target_vocab_paths,
                         shared_vocab=shared_vocab,
                         num_shards=num_shards)
    data_info_fname = os.path.join(output_prefix, C.DATA_INFO)
    logger.info("Writing data info to '%s'", data_info_fname)
    data_info.save(data_info_fname)

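    # Data config with the statistics, sequence-length limits and factor counts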
    config_data = DataConfig(data_statistics=data_statistics,
                             max_seq_len_source=max_seq_len_source,
                             max_seq_len_target=max_seq_len_target,
                             num_source_factors=len(source_fnames),
                             num_target_factors=len(target_fnames))
    config_data_fname = os.path.join(output_prefix, C.DATA_CONFIG)
    logger.info("Writing data config to '%s'", config_data_fname)
    config_data.save(config_data_fname)

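    # Write the prepared-data format version so consumers of this data can check compatibility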
    version_file = os.path.join(output_prefix, C.PREPARED_DATA_VERSION_FILE)

    with open(version_file, "w") as version_out:
        version_out.write(str(C.PREPARED_DATA_VERSION))
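
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of data_io_pt.py). This function is
# normally driven by the `sockeye-prepare-data` CLI, which builds the
# vocabularies and shard file lists before calling it. The file names, vocab
# objects and parameter values below are hypothetical placeholders.
#
#   import multiprocessing
#
#   with multiprocessing.Pool(processes=2) as pool:
#       prepare_data(source_fnames=["corpus.src"],        # one path per source factor
#                    target_fnames=["corpus.trg"],         # one path per target factor
#                    source_vocabs=[source_vocab],         # assumed pre-built Vocab (Dict[str, int])
#                    target_vocabs=[target_vocab],
#                    source_vocab_paths=[None],            # no pre-existing vocab files
#                    target_vocab_paths=[None],
#                    shared_vocab=False,
#                    max_seq_len_source=95,
#                    max_seq_len_target=95,
#                    bucketing=True,
#                    bucket_width=8,
#                    num_shards=2,
#                    output_prefix="prepared_data",        # output directory, assumed to exist
#                    pool=pool,
#                    shards=[(("shard0.src",), ("shard0.trg",)),
#                            (("shard1.src",), ("shard1.trg",))])
# ---------------------------------------------------------------------------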