in petastorm/tools/copy_dataset.py [0:0]
import operator
from functools import reduce

# NOTE: the module paths below assume the standard petastorm package layout.
from petastorm.etl.dataset_metadata import get_schema_from_dataset_url, materialize_dataset
from petastorm.fs_utils import FilesystemResolver
from petastorm.unischema import match_unischema_fields


def copy_dataset(spark, source_url, target_url, field_regex, not_null_fields, overwrite_output, partitions_count,
                 row_group_size_mb, hdfs_driver='libhdfs3'):
"""
    Creates a copy of a dataset. The new dataset can optionally contain a subset of the columns. Rows that have NULL
    values in any of the fields listed in ``not_null_fields`` are filtered out.

    :param spark: An instance of ``SparkSession``.
    :param source_url: A URL of the dataset to be copied.
    :param target_url: A URL specifying the location of the target dataset.
:param field_regex: A list of regex patterns. Only columns that match one of these patterns are copied to the new
dataset.
    :param not_null_fields: A list of fields that must have non-NULL values in the target dataset.
:param overwrite_output: If ``False`` and there is an existing path defined by ``target_url``, the operation will
fail.
    :param partitions_count: If not ``None``, the dataset is repartitioned before the write. The number of files in
        the target Parquet store is determined by this parameter.
    :param row_group_size_mb: The size of the row groups in the target dataset, in megabytes.
    :param hdfs_driver: A string denoting the HDFS driver to use (if the dataset is stored on HDFS). Current choices
        are ``libhdfs`` (Java through JNI) or ``libhdfs3`` (C++).
:return: None
"""
schema = get_schema_from_dataset_url(source_url, hdfs_driver=hdfs_driver)
fields = match_unischema_fields(schema, field_regex)
if field_regex and not fields:
field_names = list(schema.fields.keys())
        raise ValueError('Regular expressions (%s) do not match any fields (%s)'
                         % (str(field_regex), str(field_names)))
if fields:
subschema = schema.create_schema_view(fields)
else:
subschema = schema
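    # Resolve the target URL into a filesystem factory so petastorm metadata can be written to the target store.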
resolver = FilesystemResolver(target_url, spark.sparkContext._jsc.hadoopConfiguration(),
hdfs_driver=hdfs_driver, user=spark.sparkContext.sparkUser())
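    # materialize_dataset configures the Parquet row-group size for the write and, on exit,
    # adds petastorm metadata (including the Unischema) to the target dataset.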
with materialize_dataset(spark, target_url, subschema, row_group_size_mb,
filesystem_factory=resolver.filesystem_factory()):
data_frame = spark.read \
.parquet(source_url)
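        # Apply the optional column projection and NULL filtering before repartitioning and writing.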
if fields:
data_frame = data_frame.select(*[f.name for f in fields])
if not_null_fields:
not_null_condition = reduce(operator.__and__, (data_frame[f].isNotNull() for f in not_null_fields))
data_frame = data_frame.filter(not_null_condition)
if partitions_count:
data_frame = data_frame.repartition(partitions_count)
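        # Write uncompressed Parquet; in 'error' mode the write fails if the target path already exists.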
data_frame.write \
.mode('overwrite' if overwrite_output else 'error') \
.option('compression', 'none') \
.parquet(target_url)
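
# A minimal usage sketch (not part of the original module): copies an existing petastorm
# dataset, keeping only columns that match '^id.*$' and dropping rows where 'id' is NULL.
# The file:// URLs and Spark settings below are illustrative placeholders.
if __name__ == '__main__':
    from pyspark.sql import SparkSession

    session = SparkSession.builder \
        .master('local[2]') \
        .appName('copy-dataset-example') \
        .getOrCreate()
    copy_dataset(session,
                 source_url='file:///tmp/source_dataset',
                 target_url='file:///tmp/copied_dataset',
                 field_regex=['^id.*$'],
                 not_null_fields=['id'],
                 overwrite_output=True,
                 partitions_count=4,
                 row_group_size_mb=256)
    session.stop()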