def _read_partition_key()

in smallpond/logical/dataset.py [0:0]


    def _read_partition_key(path: str, data_partition_column: str, hive_partitioning: bool) -> Optional[int]:
        """
        Get the partition key of the parquet file.

        Examples
        --------
        ```
        >>> ParquetDataSet._read_partition_key("output/000.parquet", "key", hive_partitioning=False)
        1
        >>> ParquetDataSet._read_partition_key("output/key=1/000.parquet", "key", hive_partitioning=True)
        1
        ```
        """

        def parse_partition_key(key: str):
            try:
                return int(key)
            except ValueError:
                logger.error(f"cannot parse partition key '{data_partition_column}' of {path} from: {key}")
                raise

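        # Hive-style layout: the key is encoded in a "column=value" path segment.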
        if hive_partitioning:
            path_part_prefix = data_partition_column + "="
            for part in path.split(os.path.sep):
                if part.startswith(path_part_prefix):
                    return parse_partition_key(part[len(path_part_prefix) :])
            raise RuntimeError(f"cannot extract hive partition key '{data_partition_column}' from path: {path}")

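        # Non-hive layout: prefer the key recorded in the file's key-value metadata.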
        with parquet.ParquetFile(path) as file:
            kv_metadata = file.schema_arrow.metadata or file.metadata.metadata
            if kv_metadata is not None:
                for key, val in kv_metadata.items():
                    key, val = key.decode("utf-8"), val.decode("utf-8")
                    if key == PARQUET_METADATA_KEY_PREFIX + data_partition_column:
                        return parse_partition_key(val)
            if file.metadata.num_rows == 0:
                logger.warning(f"cannot read partition keys from empty parquet file: {path}")
                return None
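            # Fall back to scanning the partition column itself; the first
            # batch suffices because a file is expected to hold a single key.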
            for batch in file.iter_batches(batch_size=128, columns=[data_partition_column], use_threads=False):
                assert data_partition_column in batch.column_names, f"cannot find column '{data_partition_column}' in {batch.column_names}"
                assert batch.num_columns == 1, f"unexpected num of columns: {batch.column_names}"
                uniq_partition_keys = set(batch.columns[0].to_pylist())
                assert len(uniq_partition_keys) == 1, f"partition keys in {path} are not unique: {uniq_partition_keys}"
                return uniq_partition_keys.pop()
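
A minimal usage sketch of both branches, assuming the method is exposed on `ParquetDataSet` (as the doctest above suggests) and that `PARQUET_METADATA_KEY_PREFIX` is importable from the same module, since the method body references it unqualified:

```python
import os

import pyarrow as pa
import pyarrow.parquet as pq

# Assumption: PARQUET_METADATA_KEY_PREFIX is a module-level constant here.
from smallpond.logical.dataset import PARQUET_METADATA_KEY_PREFIX, ParquetDataSet

os.makedirs("output/key=1", exist_ok=True)
table = pa.table({"key": [1, 1], "value": ["a", "b"]})

# Hive branch: the key comes from the "key=1" path segment; the file
# contents are never opened in this branch.
pq.write_table(table, "output/key=1/000.parquet")
assert ParquetDataSet._read_partition_key("output/key=1/000.parquet", "key", hive_partitioning=True) == 1

# Metadata branch: the key is carried in the parquet key-value metadata
# under PARQUET_METADATA_KEY_PREFIX + column name, so no rows are scanned.
tagged = table.replace_schema_metadata({PARQUET_METADATA_KEY_PREFIX + "key": "1"})
pq.write_table(tagged, "output/000.parquet")
assert ParquetDataSet._read_partition_key("output/000.parquet", "key", hive_partitioning=False) == 1
```

Only when neither the path nor the metadata carries the key does the method pay for an actual read of the column, and even then it stops after the first small batch.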