def _get_or_create_dataset()

in horovod/spark/common/util.py [0:0]


def _get_or_create_dataset(key, store, df, feature_columns, label_columns,
                           validation, sample_weight_col, compress_sparse,
                           num_partitions, num_processes, verbose):
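    """Materialize the training (and optional validation) DataFrame as Parquet in the store.

    If a dataset for ``key`` is already cached, its index is reused. Otherwise the
    required columns are selected, Spark ML vector columns are converted into a
    Petastorm-compatible representation (with sparse-compression metadata when
    ``compress_sparse`` is set), the data is split into train/validation sets and
    written as Parquet, and the resulting row counts, metadata and average row size
    are cached. Returns the dataset index within the store.
    """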
    with _training_cache.lock:
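        # Fast path: reuse a dataset previously materialized in the store for this key.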
        if _training_cache.is_cached(key, store):
            dataset_idx = _training_cache.get_dataset(key)
            train_rows, val_rows, metadata, avg_row_size = _training_cache.get_dataset_properties(dataset_idx)
            train_data_path = store.get_train_data_path(dataset_idx)
            val_data_path = store.get_val_data_path(dataset_idx)
            if verbose:
                print('using cached dataframes for key: {}'.format(key))
                print('train_data_path={}'.format(train_data_path))
                print('train_rows={}'.format(train_rows))
                print('val_data_path={}'.format(val_data_path))
                print('val_rows={}'.format(val_rows))
        else:
            dataset_idx = _training_cache.next_dataset_index(key)
            train_data_path = store.get_train_data_path(dataset_idx)
            val_data_path = store.get_val_data_path(dataset_idx)
            if verbose:
                print('writing dataframes')
                print('train_data_path={}'.format(train_data_path))
                print('val_data_path={}'.format(val_data_path))

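            # Keep only the columns needed for training: features, labels, plus the
            # optional sample-weight column and validation-indicator column.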
            schema_cols = feature_columns + label_columns
            if sample_weight_col:
                schema_cols.append(sample_weight_col)
            if isinstance(validation, str):
                schema_cols.append(validation)
            df = df[schema_cols]

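            # Spark ML vector columns must be converted into a Petastorm-compatible
            # representation; metadata is computed up front only when sparse vectors
            # are to be compressed.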
            metadata = None
            if _has_vector_column(df):
                if compress_sparse:
                    metadata = _get_metadata(df)
                to_petastorm = to_petastorm_fn(schema_cols, metadata)
                df = df.rdd.map(to_petastorm).toDF()

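            # Split into train/validation sets; validation may be a hold-out ratio or
            # the name of an indicator column (see the string check above).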
            train_df, val_df, validation_ratio = _train_val_split(df, validation)

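            # Coalesce to a partition count proportional to the split size, but never
            # below the number of training processes so every worker has data to read.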
            train_partitions = max(int(num_partitions * (1.0 - validation_ratio)),
                                   num_processes)
            if verbose:
                print('train_partitions={}'.format(train_partitions))

            train_df \
                .coalesce(train_partitions) \
                .write \
                .mode('overwrite') \
                .parquet(train_data_path)

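            # Write the validation split only when one was produced.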
            if val_df:
                val_partitions = max(int(num_partitions * validation_ratio),
                                     num_processes)
                if verbose:
                    print('val_partitions={}'.format(val_partitions))

                val_df \
                    .coalesce(val_partitions) \
                    .write \
                    .mode('overwrite') \
                    .parquet(val_data_path)

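            # Read back row counts, schema metadata and average row size from the
            # Parquet files that were just written.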
            train_rows, val_rows, pq_metadata, avg_row_size = get_simple_meta_from_parquet(
                store, label_columns, feature_columns, sample_weight_col, dataset_idx)

            if verbose:
                print('train_rows={}'.format(train_rows))
            if val_df:
                if val_rows == 0:
                    raise ValueError(
                        'Validation DataFrame does not contain any samples with validation param {}'
                        .format(validation))
                if verbose:
                    print('val_rows={}'.format(val_rows))

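            # Prefer the metadata computed for compressed sparse vectors; otherwise fall
            # back to the metadata derived from the written Parquet files, then cache
            # the dataset properties under this index.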
            metadata = metadata or pq_metadata
            _training_cache.set_dataset_properties(
                dataset_idx, (train_rows, val_rows, metadata, avg_row_size))
        return dataset_idx
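
A minimal invocation sketch (not part of util.py), assuming a Spark DataFrame `df` and a Horovod store instance are already available; the cache key, column names and partition/process counts below are placeholders:

# Hypothetical usage sketch -- the surrounding module provides _training_cache and helpers.
dataset_idx = _get_or_create_dataset(
    key=('example-cache-key',),      # placeholder cache key
    store=store,                     # e.g. a horovod.spark.common.store.Store subclass
    df=df,                           # Spark DataFrame holding feature and label columns
    feature_columns=['features'],    # placeholder column names
    label_columns=['label'],
    validation=0.2,                  # hold-out fraction, or the name of an indicator column
    sample_weight_col=None,
    compress_sparse=False,
    num_partitions=8,
    num_processes=4,
    verbose=1)
train_data_path = store.get_train_data_path(dataset_idx)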