in smallpond/logical/dataset.py [0:0]
def _read_partition_key(path: str, data_partition_column: str, hive_partitioning: bool) -> Optional[int]:
"""
Get the partition key of the parquet file.
Examples
--------
```
>>> ParquetDataSet._read_partition_key("output/000.parquet", "key", hive_partitioning=False)
1
>>> ParquetDataSet._read_partition_key("output/key=1/000.parquet", "key", hive_partitioning=True)
1
```
"""

    def parse_partition_key(key: str):
        try:
            return int(key)
        except ValueError:
            logger.error(f"cannot parse partition key '{data_partition_column}' of {path} from: {key}")
            raise

    if hive_partitioning:
        # Hive-style layout: the key is encoded in a path component such as "key=1".
        path_part_prefix = data_partition_column + "="
        for part in path.split(os.path.sep):
            if part.startswith(path_part_prefix):
                return parse_partition_key(part[len(path_part_prefix):])
        raise RuntimeError(f"cannot extract hive partition key '{data_partition_column}' from path: {path}")

    # No hive partitioning: look up the key in the parquet file itself.
    with parquet.ParquetFile(path) as file:
        # First try the key-value metadata written alongside the schema.
        kv_metadata = file.schema_arrow.metadata or file.metadata.metadata
        if kv_metadata is not None:
            for key, val in kv_metadata.items():
                key, val = key.decode("utf-8"), val.decode("utf-8")
                if key == PARQUET_METADATA_KEY_PREFIX + data_partition_column:
                    return parse_partition_key(val)
        if file.metadata.num_rows == 0:
            logger.warning(f"cannot read partition keys from empty parquet file: {path}")
            return None
        # Fall back to scanning the partition column and checking that it holds a single distinct value.
        for batch in file.iter_batches(batch_size=128, columns=[data_partition_column], use_threads=False):
            assert data_partition_column in batch.column_names, f"cannot find column '{data_partition_column}' in {batch.column_names}"
            assert batch.num_columns == 1, f"unexpected num of columns: {batch.column_names}"
            uniq_partition_keys = set(batch.columns[0].to_pylist())
            assert uniq_partition_keys and len(uniq_partition_keys) == 1, f"partition keys found in {path} not unique: {uniq_partition_keys}"
            return uniq_partition_keys.pop()