def copy_dataset()

in petastorm/tools/copy_dataset.py

import operator
from functools import reduce

# Module paths below are assumed from the petastorm package layout.
from petastorm.etl.dataset_metadata import get_schema_from_dataset_url, materialize_dataset
from petastorm.fs_utils import FilesystemResolver
from petastorm.unischema import match_unischema_fields


def copy_dataset(spark, source_url, target_url, field_regex, not_null_fields, overwrite_output, partitions_count,
                 row_group_size_mb, hdfs_driver='libhdfs3'):
    """
    Creates a copy of a dataset. The copy can optionally contain only a subset of the columns. Rows that have NULL
    values in any of the fields listed in the ``not_null_fields`` argument are filtered out.


    :param spark: An instance of ``SparkSession``.
    :param source_url: A URL of the dataset to be copied.
    :param target_url: A URL specifying the location of the target dataset.
    :param field_regex: A list of regex patterns. Only columns that match one of these patterns are copied to the new
      dataset.
    :param not_null_fields: A list of fields that must have non-NULL values in the target dataset.
    :param overwrite_output: If ``False`` and there is an existing path defined by ``target_url``, the operation will
      fail.
    :param partitions_count: If not ``None``, the dataset is repartitioned before write. Number of files in the target
      Parquet store is defined by this parameter.
    :param row_group_size_mb: The size of the rowgroup in the target dataset. Specified in megabytes.
    :param hdfs_driver: A string denoting the HDFS driver to use (if the dataset resides on HDFS). Current choices are
      ``libhdfs`` (Java through JNI) or ``libhdfs3`` (C++).
    :return: None
    """
    # Read the Unischema stored in the source dataset's petastorm metadata.
    schema = get_schema_from_dataset_url(source_url, hdfs_driver=hdfs_driver)

    # Select the fields whose names match any of the supplied regex patterns.
    fields = match_unischema_fields(schema, field_regex)

    if field_regex and not fields:
        field_names = list(schema.fields.keys())
        raise ValueError('Regular expressions (%s) do not match any fields (%s)'
                         % (str(field_regex), str(field_names)))

    # Narrow the schema to the matched fields; if no patterns were given, keep the full schema.
    if fields:
        subschema = schema.create_schema_view(fields)
    else:
        subschema = schema

    # Resolve the target filesystem; materialize_dataset writes the petastorm metadata
    # for the new dataset when the block below exits.
    resolver = FilesystemResolver(target_url, spark.sparkContext._jsc.hadoopConfiguration(),
                                  hdfs_driver=hdfs_driver, user=spark.sparkContext.sparkUser())
    with materialize_dataset(spark, target_url, subschema, row_group_size_mb,
                             filesystem_factory=resolver.filesystem_factory()):
        data_frame = spark.read \
            .parquet(source_url)

        if fields:
            data_frame = data_frame.select(*[f.name for f in fields])

        # Keep only rows where every field listed in not_null_fields is non-NULL.
        if not_null_fields:
            not_null_condition = reduce(operator.__and__, (data_frame[f].isNotNull() for f in not_null_fields))
            data_frame = data_frame.filter(not_null_condition)

        if partitions_count:
            data_frame = data_frame.repartition(partitions_count)

        # Write the target Parquet store; fail if it already exists unless overwrite_output is set.
        data_frame.write \
            .mode('overwrite' if overwrite_output else 'error') \
            .option('compression', 'none') \
            .parquet(target_url)
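
A minimal usage sketch (the paths, column names, and Spark settings below are hypothetical and not taken from the module):

from pyspark.sql import SparkSession

spark = SparkSession.builder \
    .master('local[2]') \
    .appName('copy_dataset_example') \
    .getOrCreate()

# Copy only the 'id' and 'label' columns, drop rows where 'label' is NULL,
# and repartition the result into a single output file.
copy_dataset(spark,
             source_url='file:///tmp/source_dataset',
             target_url='file:///tmp/target_dataset',
             field_regex=['^id$', '^label$'],
             not_null_fields=['label'],
             overwrite_output=True,
             partitions_count=1,
             row_group_size_mb=256,
             hdfs_driver='libhdfs3')

spark.stop()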