in horovod/spark/common/util.py [0:0]
def _get_col_info(df):
    """
    Infer the Python types, shape and maximum size of every column in ``df``.

    NOTE: This function processes the entire DataFrame, and can therefore be very expensive to run.

    TODO(travis): Only run this if user sets compress_sparse param, otherwise convert all to Array.

    :param df: DataFrame to inspect; its ``rdd`` and ``schema.names`` are used.
    :return: tuple ``(all_col_types, col_shapes, col_max_sizes)`` of dicts keyed by column name,
             holding respectively the set of value types seen in the column, the single shape
             shared by all of its rows, and the largest size seen in it.
    :raises ValueError: if the rows of a column do not all share one shape, or if their sizes
                        differ and no value of the column is a SparseVector.
    """

    def describe_row(row):
        # Emit one (column name, ({type}, {shape}, {size})) pair per field of the row.
        described = []
        for name, value in row.asDict().items():
            if isinstance(value, DenseVector):
                # a dense vector stores every element, so shape and size coincide
                extent = stored = value.array.shape[0]
            elif isinstance(value, SparseVector):
                # shape is the declared length, size counts only the stored indices
                extent = value.size
                stored = value.indices.shape[0]
            elif isinstance(value, list):
                extent = stored = len(value)
            else:
                # anything else is treated as a single element
                extent = stored = 1
            described.append((name, ({type(value)}, {extent}, {stored})))
        return described

    def combine(left, right):
        # Union the three sets pairwise; shapes and sizes are then cut down to their
        # extremes, so each of those sets never holds more than two values.
        kinds, extents, stored = (mine | theirs for mine, theirs in zip(left, right))
        return kinds, {min(extents), max(extents)}, {min(stored), max(stored)}

    collected = df.rdd.flatMap(describe_row).reduceByKey(combine).collect()

    all_col_types = {}
    col_shapes = {}
    col_max_sizes = {}
    for name, (kinds, extents, stored) in collected:
        all_col_types[name] = kinds
        col_shapes[name] = extents
        col_max_sizes[name] = stored

    for column in df.schema.names:
        # Every row of a column has to agree on the shape.
        distinct_shapes = col_shapes[column]
        if len(distinct_shapes) != 1:
            raise ValueError(
                'Column {col} does not have uniform shape. '
                'shape set: {shapes_set}'.format(col=column, shapes_set=distinct_shapes))
        # Replace the one-element set with the shape itself.
        col_shapes[column] = distinct_shapes.pop()

        # Sizes may only differ between rows when the column contains SparseVectors.
        distinct_sizes = col_max_sizes[column]
        sizes_vary = len(distinct_sizes) > 1
        if sizes_vary and SparseVector not in all_col_types[column]:
            raise ValueError(
                'Rows of column {col} have varying sizes. This is only allowed if datatype is '
                'SparseVector or a mix of Sparse and DenseVector.'.format(col=column))
        col_max_sizes[column] = max(distinct_sizes)

    return all_col_types, col_shapes, col_max_sizes