def partition()

in smallpond/execution/task.py


    def partition(self, batch_index: int, input_dataset: ParquetDataSet):
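        # Repartition the rows of `input_dataset` into `self.npartitions` partitions,
        # keyed either by an existing partition column, at random, or by hashing key columns.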
        import polars

        self.add_elapsed_time()
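        # Load the whole input dataset into a single Arrow table, using up to cpu_limit parallel readers.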
        table = input_dataset.to_arrow_table(max_workers=self.cpu_limit)
        self.perf_metrics["num input rows"] += table.num_rows
        elapsed_time = self.add_elapsed_time("input load time (secs)")
        logger.debug(f"load input dataset: {table.nbytes/MB:.3f}MB, {table.num_rows} rows, {elapsed_time:.3f} secs")

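        # Compute a partition key for every row:
        #   - shuffle_only: reuse the values already stored in the partition column,
        #   - random_shuffle: assign rows to partitions uniformly at random,
        #   - default: hash the configured key columns and take the hash modulo npartitions.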
        if self.shuffle_only:
            partition_keys = table.column(self.data_partition_column)
        elif self.random_shuffle:
            partition_keys = arrow.array(self.numpy_random_gen.integers(self.npartitions, size=table.num_rows))
        else:
            hash_columns = polars.from_arrow(table.select(self.hash_columns))
            hash_values = hash_columns.hash_rows(*self.fixed_rand_seeds)
            partition_keys = (hash_values % self.npartitions).to_arrow()

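        # Attach the computed keys as the partition column, replacing any existing column of that name.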
        if self.data_partition_column in table.column_names:
            table = table.drop_columns(self.data_partition_column)
        table = table.append_column(self.data_partition_column, partition_keys)
        elapsed_time = self.add_elapsed_time("compute time (secs)")
        logger.debug(f"generate partition keys: {elapsed_time:.3f} secs")

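        # Process the table in row slices to bound peak memory: each slice covers at most half
        # the table (capped at 100 * 1024 * 1024 rows) and at least DEFAULT_BATCH_SIZE rows.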
        table_slice_size = max(DEFAULT_BATCH_SIZE, min(table.num_rows // 2, 100 * 1024 * 1024))
        num_iterations = math.ceil(table.num_rows / table_slice_size)

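        # Helper: write one batch of (partition index, DataFrame) pairs to their partitions
        # and return the total number of rows written.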
        def write_partition_data(
            partition_batch: List[Tuple[int, polars.DataFrame]],
        ) -> int:
            total_num_rows = 0
            for partition_idx, partition_data in partition_batch:
                total_num_rows += len(partition_data)
                self._write_to_partition(partition_idx, partition_data.to_arrow())
            return total_num_rows

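        # Partition each row slice with polars and write the resulting partitions out.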
        for table_slice_idx, table_slice_offset in enumerate(range(0, table.num_rows, table_slice_size)):
            table_slice = table.slice(table_slice_offset, table_slice_size)
            logger.debug(f"table slice #{table_slice_idx+1}/{num_iterations}: {table_slice.nbytes/MB:.3f}MB, {table_slice.num_rows} rows")

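            # Convert the Arrow slice to a polars DataFrame and drop the Arrow reference.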
            df = polars.from_arrow(table_slice)
            del table_slice
            elapsed_time = self.add_elapsed_time("compute time (secs)")
            logger.debug(f"convert from arrow table #{table_slice_idx+1}/{num_iterations}: {elapsed_time:.3f} secs")

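            # Split the slice into one DataFrame per partition key.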
            partitioned_dfs = df.partition_by(
                [self.data_partition_column],
                maintain_order=False,
                include_key=not self.drop_partition_column,
                as_dict=True,
            )
            partitioned_dfs = [(partition_idx, df) for (partition_idx,), df in partitioned_dfs.items()]
            del df
            elapsed_time = self.add_elapsed_time("compute time (secs)")
            logger.debug(f"build partition data #{table_slice_idx+1}/{num_iterations}: {elapsed_time:.3f} secs")

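            # Spread the partitioned DataFrames across the I/O workers and tally the rows written.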
            partition_batches = split_into_rows(partitioned_dfs, self.num_workers)
            self.perf_metrics["num output rows"] += sum(self.io_workers.map(write_partition_data, partition_batches))
            elapsed_time = self.add_elapsed_time("output dump time (secs)")
            logger.debug(f"write partition data #{table_slice_idx+1}/{num_iterations}: {elapsed_time:.3f} secs")
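
For reference, the hash-partitioning step above can be reproduced standalone with polars. This is a minimal sketch, not code from smallpond: the DataFrame contents, the key column name, the seed, and npartitions are illustrative. With recent polars versions, partition_by(as_dict=True) returns tuple keys, which is why the unpacking below matches the loop in the task.

    import polars as pl

    npartitions = 4  # illustrative value
    df = pl.DataFrame({"user_id": [1, 2, 3, 4, 5], "value": ["a", "b", "c", "d", "e"]})

    # Hash the key column(s) row-wise and map every row to a partition index in [0, npartitions).
    keys = df.select("user_id").hash_rows(seed=42) % npartitions

    # Group rows by partition index, mirroring partition_by(..., as_dict=True) in the task above.
    parts = df.with_columns(keys.alias("__partition__")).partition_by(
        "__partition__", as_dict=True, include_key=False
    )
    for (partition_idx,), partition_df in parts.items():
        print(partition_idx, partition_df.height)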