def _get_col_info()

in horovod/spark/common/util.py [0:0]


def _get_col_info(df):
    """
    Summarise the Python types, shape and size observed in every column of ``df``.

    Each row is mapped to one ``(column name, ({type}, {shape}, {size}))`` record per
    column; the records are then reduced per column name and collected on the driver.

    NOTE: This function processes the entire DataFrame, and can therefore be very expensive to run.

    TODO(travis): Only run this if user sets compress_sparse param, otherwise convert all to Array.

    :param df: DataFrame whose ``rdd`` and ``schema.names`` are inspected.
    :return: tuple ``(all_col_types, col_shapes, col_max_sizes)`` of dicts keyed by column
        name, holding respectively the set of Python types seen, the single shape shared
        by all rows, and the largest size seen.
    :raises ValueError: if a column has more than one shape, or has more than one size
        while no SparseVector was seen in it.
    """

    def measure(value):
        # Returns (shape, size) for a single cell.
        if isinstance(value, DenseVector):
            # a dense vector stores every element, so shape and size coincide
            length = value.array.shape[0]
            return length, length
        if isinstance(value, SparseVector):
            # shape is the full vector length; size counts only the stored (nonzero) entries
            return value.size, value.indices.shape[0]
        if isinstance(value, list):
            length = len(value)
            return length, length
        # anything else is treated as a scalar
        return 1, 1

    def get_meta(row):
        records = []
        for name, value in row.asDict().items():
            shape, size = measure(value)
            # singleton sets so that merge() can combine records with set union
            records.append((name, ({type(value)}, {shape}, {size})))
        return records

    def merge(x, y):
        dtypes, shapes, sizes = (left | right for left, right in zip(x, y))
        # only the extremes are kept, so the sets never grow beyond two elements
        return dtypes, {min(shapes), max(shapes)}, {min(sizes), max(sizes)}

    per_column = df.rdd.flatMap(get_meta).reduceByKey(merge).collect()

    all_col_types = {name: dtypes for name, (dtypes, _, _) in per_column}
    col_shapes = {name: shapes for name, (_, shapes, _) in per_column}
    col_max_sizes = {name: sizes for name, (_, _, sizes) in per_column}

    for col in df.schema.names:
        # All rows in every column must have the same shape
        observed_shapes = col_shapes[col]
        if len(observed_shapes) != 1:
            raise ValueError(
                'Column {col} does not have uniform shape. '
                'shape set: {shapes_set}'.format(col=col, shapes_set=observed_shapes))
        col_shapes[col] = next(iter(observed_shapes))

        # Sizes may only differ between rows when SparseVectors are present in the column
        observed_sizes = col_max_sizes[col]
        if len(observed_sizes) > 1 and SparseVector not in all_col_types[col]:
            raise ValueError(
                'Rows of column {col} have varying sizes. This is only allowed if datatype is '
                'SparseVector or a mix of Sparse and DenseVector.'.format(col=col))
        col_max_sizes[col] = max(observed_sizes)

    return all_col_types, col_shapes, col_max_sizes