in smallpond/execution/task.py
def run(self) -> bool:
    self.add_elapsed_time()
    if self.skip_when_any_input_empty:
        return True
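    # collect row ranges, estimated byte sizes, row counts, and file sets for each Parquet input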
    input_row_ranges = [dataset.resolved_row_ranges for dataset in self.input_datasets if isinstance(dataset, ParquetDataSet)]
    input_byte_size = [sum(row_range.estimated_data_size for row_range in row_ranges) for row_ranges in input_row_ranges]
    input_num_rows = [sum(row_range.num_rows for row_range in row_ranges) for row_ranges in input_row_ranges]
    input_files = [set(row_range.path for row_range in row_ranges) for row_ranges in input_row_ranges]
    self.perf_metrics["num input rows"] += sum(input_num_rows)
    self.perf_metrics["input data size (MB)"] += sum(input_byte_size) / MB
    # calculate the max streaming batch size based on memory limit
    avg_input_row_size = sum(self.compute_avg_row_size(nbytes, num_rows) for nbytes, num_rows in zip(input_byte_size, input_num_rows))
    max_batch_rows = self.max_batch_size // avg_input_row_size
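    # first run: clamp the configured batch size to the memory budget and derive the batch count;
    # otherwise reuse the values recorded in the saved runtime state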
    if self.runtime_state is None:
        if self.streaming_batch_size > max_batch_rows:
            logger.warning(
                f"reduce streaming batch size from {self.streaming_batch_size} to {max_batch_rows} (approx. {self.max_batch_size/GB:.3f}GB)"
            )
            self.streaming_batch_size = max_batch_rows
        self.streaming_batch_count = max(
            1,
            max(map(len, input_files)),
            math.ceil(max(input_num_rows) / self.streaming_batch_size),
        )
    else:
        self.streaming_batch_size = self.runtime_state.streaming_batch_size
        self.streaming_batch_count = self.runtime_state.streaming_batch_count
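    # stream batches through the task's process function and write the outputs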
    try:
        conn = None
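        # optionally read inputs through an in-memory DuckDB connection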
        if self.use_duckdb_reader:
            conn = duckdb.connect(database=":memory:", config={"allow_unsigned_extensions": "true"})
            self.prepare_connection(conn)
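        # open one streaming batch reader per input dataset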
        input_readers = [
            dataset.to_batch_reader(
                batch_size=self.streaming_batch_size,
                conn=conn,
            )
            for dataset in self.input_datasets
        ]
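        # first run: record the batch parameters in a fresh runtime state;
        # on resume: restore the readers' positions from the saved state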
        if self.runtime_state is None:
            self.runtime_state = ArrowStreamTask.RuntimeState(self.streaming_batch_size, self.streaming_batch_count)
        else:
            self.restore_input_state(self.runtime_state, input_readers)
        self.runtime_state.last_batch_indices = None
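        # lazily apply the task's process function to the input streams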
        output_iter = self._call_process(self.ctx.set_current_task(self), input_readers)
        self.add_elapsed_time("compute time (secs)")
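        # optionally overlap computation with output writes via a background I/O thread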
        if self.background_io_thread:
            with ConcurrentIter(output_iter) as concurrent_iter:
                return self.dump_output(concurrent_iter)
        else:
            return self.dump_output(output_iter)
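    # surface Arrow allocation failures as the framework's OutOfMemory error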
    except arrow.lib.ArrowMemoryError as ex:
        raise OutOfMemory(f"{self.key} failed with OOM error") from ex
    finally:
        if conn is not None:
            conn.close()