def _convert_precision()

in petastorm/spark/spark_dataset_converter.py [0:0]


def _convert_precision(df, dtype):
    if dtype is None:
        return df

    if dtype != "float32" and dtype != "float64":
        raise ValueError("dtype {} is not supported. \
            Use 'float32' or float64".format(dtype))

    source_type, target_type = (DoubleType, FloatType) \
        if dtype == "float32" else (FloatType, DoubleType)

    logger.warning("Converting floating-point columns to %s", dtype)

    for field in df.schema:
        col_name = field.name
        if isinstance(field.dataType, source_type):
            df = df.withColumn(col_name, df[col_name].cast(target_type()))
        elif isinstance(field.dataType, ArrayType) and \
                isinstance(field.dataType.elementType, source_type):
            df = df.withColumn(col_name, df[col_name].cast(ArrayType(target_type())))
    return df