in horovod/spark/common/util.py [0:0]
def _get_or_create_dataset(key, store, df, feature_columns, label_columns,
                           validation, sample_weight_col, compress_sparse,
                           num_partitions, num_processes, verbose):
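    """Return the index of a cached Petastorm-compatible Parquet dataset for ``key``,
    materializing the train/validation Parquet files from ``df`` under the store's
    paths if no cached dataset exists yet.
    """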
    with _training_cache.lock:
        if _training_cache.is_cached(key, store):
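            # Cache hit: reuse the Parquet files and row counts that were
            # materialized earlier for this key.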
            dataset_idx = _training_cache.get_dataset(key)
            train_rows, val_rows, metadata, avg_row_size = _training_cache.get_dataset_properties(dataset_idx)
            train_data_path = store.get_train_data_path(dataset_idx)
            val_data_path = store.get_val_data_path(dataset_idx)
            if verbose:
                print('using cached dataframes for key: {}'.format(key))
                print('train_data_path={}'.format(train_data_path))
                print('train_rows={}'.format(train_rows))
                print('val_data_path={}'.format(val_data_path))
                print('val_rows={}'.format(val_rows))
        else:
            dataset_idx = _training_cache.next_dataset_index(key)
            train_data_path = store.get_train_data_path(dataset_idx)
            val_data_path = store.get_val_data_path(dataset_idx)
            if verbose:
                print('writing dataframes')
                print('train_data_path={}'.format(train_data_path))
                print('val_data_path={}'.format(val_data_path))
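            # Keep only the columns that training needs: features, labels, and
            # (optionally) the sample-weight and validation columns.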
            schema_cols = feature_columns + label_columns
            if sample_weight_col:
                schema_cols.append(sample_weight_col)
            if isinstance(validation, str):
                schema_cols.append(validation)
            df = df[schema_cols]
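            # DataFrames containing Spark ML vector columns are converted row by row
            # into a Petastorm-compatible representation; the metadata that drives
            # sparse compression is computed only when compress_sparse is set.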
            metadata = None
            if _has_vector_column(df):
                if compress_sparse:
                    metadata = _get_metadata(df)
                to_petastorm = to_petastorm_fn(schema_cols, metadata)
                df = df.rdd.map(to_petastorm).toDF()
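            # Split into training and validation sets; `validation` may be a fraction
            # or the name of a column that marks validation rows.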
            train_df, val_df, validation_ratio = _train_val_split(df, validation)
            train_partitions = max(int(num_partitions * (1.0 - validation_ratio)),
                                   num_processes)
            if verbose:
                print('train_partitions={}'.format(train_partitions))
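            # Write the training set as Parquet, coalesced so that there are at least
            # num_processes partitions.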
            train_df \
                .coalesce(train_partitions) \
                .write \
                .mode('overwrite') \
                .parquet(train_data_path)
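            # Write the validation set the same way when the split produced one.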
            if val_df:
                val_partitions = max(int(num_partitions * validation_ratio),
                                     num_processes)
                if verbose:
                    print('val_partitions={}'.format(val_partitions))
                val_df \
                    .coalesce(val_partitions) \
                    .write \
                    .mode('overwrite') \
                    .parquet(val_data_path)
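            # Read the row counts, schema metadata, and average row size back from the
            # Parquet files that were just written.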
            train_rows, val_rows, pq_metadata, avg_row_size = get_simple_meta_from_parquet(
                store, label_columns, feature_columns, sample_weight_col, dataset_idx)
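            # Report row counts and fail fast if the requested validation split
            # produced no rows.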
            if verbose:
                print('train_rows={}'.format(train_rows))
            if val_df:
                if val_rows == 0:
                    raise ValueError(
                        'Validation DataFrame does not contain any samples with '
                        'validation param {}'.format(validation))
                if verbose:
                    print('val_rows={}'.format(val_rows))
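            # Prefer the metadata computed during sparse compression; otherwise fall
            # back to the metadata inferred from the written Parquet files, then cache
            # the dataset properties for subsequent calls with the same key.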
            metadata = metadata or pq_metadata
            _training_cache.set_dataset_properties(
                dataset_idx, (train_rows, val_rows, metadata, avg_row_size))
        return dataset_idx
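
A minimal usage sketch, under some assumptions: a running SparkSession named `spark`, a writable local path for the store, and a placeholder cache key (`'example-key'` here is hypothetical; the real key is derived from the DataFrame, store, and validation settings elsewhere in this module). In practice this private helper is reached through Horovod's Spark estimator API rather than called directly; the snippet only illustrates the arguments it expects.

from horovod.spark.common import util
from horovod.spark.common.store import Store

store = Store.create('/tmp/horovod_work_dir')   # resolves to a local-filesystem store for this path
df = spark.createDataFrame([(float(i), i % 2) for i in range(100)], ['x', 'label'])
dataset_idx = util._get_or_create_dataset(
    key='example-key',            # placeholder key, not how real keys are built
    store=store, df=df,
    feature_columns=['x'], label_columns=['label'],
    validation=None,              # or a fraction / name of a validation column
    sample_weight_col=None, compress_sparse=False,
    num_partitions=4, num_processes=2, verbose=1)
# A second call with the same key would hit the cache and return the same index.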