in smallpond/execution/task.py [0:0]
def dump_output(self, output_iter: Iterable[StreamOutput]):
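    """Drain output_iter, buffering tables into parquet row groups and
    writing them to a sequence of output files; a new file is started at
    each checkpoint."""

    # Local helper: write the buffered table as a row group and record
    # row-count / timing metrics; empty tables are skipped.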
    def write_table(writer: parquet.ParquetWriter, table: arrow.Table):
        if table.num_rows == 0:
            return
        writer.write_table(table, self.parquet_row_group_size)
        self.perf_metrics["num output rows"] += table.num_rows
        self.add_elapsed_time("output dump time (secs)")

    create_checkpoint = False
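    # Stagger the first periodic checkpoint with random jitter so that
    # parallel tasks do not all checkpoint at the same moment.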
    last_checkpoint_time = time.time() - self.random_float() * self.secs_checkpoint_interval / 2
    output: StreamOutput = next(output_iter, None)
    self.add_elapsed_time("compute time (secs)")

    if output is None:
        logger.warning("user's process method returned no output")
        return True

    if self.parquet_row_group_size == DEFAULT_ROW_GROUP_SIZE:
        # adjust row group size if it is not set by user
        self.adjust_row_group_size(
            self.streaming_batch_count * output.output_table.nbytes,
            self.streaming_batch_count * output.output_table.num_rows,
        )

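    # Push the first output back onto the head of the stream and start
    # with an empty buffer that carries the output schema.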
    output_iter = itertools.chain([output], output_iter)
    buffered_output = output.output_table.slice(length=0)

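    # Each iteration of this loop produces one parquet file; the writer
    # loop below breaks out at every checkpoint so that finished files
    # can be recorded in runtime_state.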
    for output_file_idx in itertools.count():
        output_path = os.path.join(
            self.runtime_output_abspath,
            f"{self.output_filename}-{output_file_idx}.parquet",
        )
        output_file = open(output_path, "wb", buffering=32 * MB)
        try:
            with parquet.ParquetWriter(
                where=output_file,
                schema=buffered_output.schema.with_metadata(self.parquet_kv_metadata_bytes()),
                use_dictionary=self.parquet_dictionary_encoding,
                compression=(self.parquet_compression if self.parquet_compression is not None else "NONE"),
                compression_level=self.parquet_compression_level,
                write_batch_size=max(16 * 1024, self.parquet_row_group_size // 8),
                data_page_size=max(64 * MB, self.parquet_row_group_bytes // 8),
            ) as writer:
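                # Accumulate incoming tables until a full row group is
                # buffered, then flush the buffer as a single row group.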
                while (output := next(output_iter, None)) is not None:
                    self.add_elapsed_time("compute time (secs)")
                    if buffered_output.num_rows + output.output_table.num_rows < self.parquet_row_group_size:
                        buffered_output = arrow.concat_tables((buffered_output, output.output_table))
                    else:
                        write_table(writer, buffered_output)
                        buffered_output = output.output_table
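                    # Checkpoint when the producer forces one, or when the
                    # checkpoint interval has elapsed and there are batch
                    # indices to record; breaking out closes this file.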
                    periodic_checkpoint = bool(output.batch_indices) and (time.time() - last_checkpoint_time) >= self.secs_checkpoint_interval
                    create_checkpoint = output.force_checkpoint or periodic_checkpoint
                    if create_checkpoint:
                        self.runtime_state.update_batch_offsets(output.batch_indices)
                        last_checkpoint_time = time.time()
                        break
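                # Flush whatever is still buffered before the writer closes,
                # leaving an empty buffer that keeps the schema.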
                if buffered_output is not None:
                    write_table(writer, buffered_output)
                    buffered_output = buffered_output.slice(length=0)
        finally:
            if isinstance(output_file, io.IOBase):
                output_file.close()

        assert buffered_output is None or buffered_output.num_rows == 0
        self.runtime_state.streaming_output_paths.append(output_path)
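
        # The input stream is exhausted: the file just closed is the last one.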
        if output is None:
            break
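
        # Otherwise the writer loop broke out to create a checkpoint: copy
        # this task, strip its complex attributes, and push the snapshot
        # to exec_cq.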
        if create_checkpoint and self.exec_cq is not None:
            checkpoint = copy.copy(self)
            checkpoint.clean_complex_attrs()
            self.exec_cq.push(checkpoint, buffering=False)
            logger.debug(
                f"created and sent checkpoint #{self.runtime_state.max_batch_offsets}/{self.streaming_batch_count}: {self.runtime_state}"
            )

    return True
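
# A minimal, self-contained sketch (hypothetical numbers and names, not part
# of the task above) of the buffering rule in the writer loop: tables
# accumulate until adding the next batch would reach the row group size,
# then the buffer is flushed and the new batch starts the next buffer.
def _buffering_sketch():
    row_group_size = 1000
    buffered = 0
    for incoming in (400, 400, 400):
        if buffered + incoming < row_group_size:
            buffered += incoming  # grows to 400, then 800
        else:
            print(f"flush {buffered} rows")  # flushes the 800 buffered rows
            buffered = incoming  # next buffer starts with the 400 new rows
    print(f"final flush {buffered} rows")  # trailing flush, as after the loop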