in petastorm/spark/spark_dataset_converter.py
import logging

from pyspark.sql.types import ArrayType, DoubleType, FloatType

logger = logging.getLogger(__name__)


def _convert_precision(df, dtype):
    """Cast the DataFrame's floating-point columns to ``dtype``.

    ``dtype`` may be None (return ``df`` unchanged), 'float32', or 'float64'.
    Both scalar float columns and arrays of floats are converted.
    """
    if dtype is None:
        return df
    if dtype not in ("float32", "float64"):
        raise ValueError("dtype {} is not supported. "
                         "Use 'float32' or 'float64'.".format(dtype))
    # Pick the source and target Spark SQL types for the requested precision.
    source_type, target_type = (DoubleType, FloatType) \
        if dtype == "float32" else (FloatType, DoubleType)
    logger.warning("Converting floating-point columns to %s", dtype)
    # Cast matching scalar columns, and array columns whose element type
    # matches; all other columns are left untouched.
    for field in df.schema:
        col_name = field.name
        if isinstance(field.dataType, source_type):
            df = df.withColumn(col_name, df[col_name].cast(target_type()))
        elif isinstance(field.dataType, ArrayType) and \
                isinstance(field.dataType.elementType, source_type):
            df = df.withColumn(col_name, df[col_name].cast(ArrayType(target_type())))
    return df
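
To make the behavior concrete, here is a minimal usage sketch. It is not part of the petastorm source; the local SparkSession setup, schema, and column names are illustrative assumptions.

# Minimal usage sketch (assumed setup, not from petastorm itself):
# a double column and an array-of-double column are both downcast to float.
from pyspark.sql import SparkSession
from pyspark.sql.types import ArrayType, DoubleType, StructField, StructType

spark = SparkSession.builder.master("local[1]").getOrCreate()

schema = StructType([
    StructField("x", DoubleType()),
    StructField("xs", ArrayType(DoubleType())),
])
df = spark.createDataFrame([(1.0, [1.0, 2.0])], schema=schema)

converted = _convert_precision(df, "float32")
converted.printSchema()
# root
#  |-- x: float (nullable = true)
#  |-- xs: array (nullable = true)
#  |    |-- element: float (containsNull = true)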