def _pack_ffd()

in trl/data_utils.py [0:0]


def _pack_ffd(examples: pa.Table, seq_length: int) -> pa.Table:
    """Pack sequences in a pyarrow Table using First Fit Decreasing strategy.

    The function works in four steps:

    1. A `position_ids` column is appended, holding `0, 1, ..., len - 1` for each row of `input_ids`.
    2. Every list / large-list column (including `position_ids`) is truncated to at most `seq_length` items per row.
    3. Rows are visited from longest to shortest (lengths are read from the first list column) and assigned to bins
       with the help of `_SegmentTree`, which tracks the remaining capacities currently available.
    4. Rows are reordered bin by bin and the offsets of every list column are rewritten so that each bin becomes a
       single row.

    Args:
        examples (`pa.Table`):
            Table to pack. It must contain an `input_ids` column.
        seq_length (`int`):
            Maximum number of items per row, both for truncation and as the capacity of a bin.

    Returns:
        `pa.Table`: Table with the same columns as `examples` plus `position_ids`, with list columns packed.

    Note:
        The bin offsets are computed from the first list column only and reused for every list column, so all list
        columns are assumed to have identical per-row lengths. Non-list columns are reordered but not merged.
        NOTE(review): a non-list column therefore keeps one entry per input row while list columns end up with one
        entry per bin -- confirm callers only pass list columns.
    """
    # Add position_ids to the examples. Only the per-row lengths are needed, so ask Arrow for them instead of
    # converting every token id into a Python object via `to_pylist()` on the column itself.
    input_ids = examples["input_ids"]
    row_lengths = pc.list_value_length(input_ids).to_pylist()
    position_ids_python = [list(range(row_length)) for row_length in row_lengths]
    position_ids_array = pa.array(position_ids_python, type=input_ids.type)
    examples = examples.append_column("position_ids", position_ids_array)

    # Truncate every list column to `seq_length` and remember the first one: it provides the row lengths below.
    columns = []
    list_column_idx = None
    for idx, column in enumerate(examples.columns):
        if pa.types.is_list(column.type) or pa.types.is_large_list(column.type):
            column = pc.list_slice(column, 0, seq_length)
            if list_column_idx is None:
                list_column_idx = idx
        columns.append(column)
    examples = pa.Table.from_arrays(columns, names=examples.column_names)

    # Pair each (truncated) length with its row index and sort by length, longest first ("decreasing").
    ids = np.arange(len(examples))
    assert list_column_idx is not None
    lengths = pc.make_struct(pc.list_value_length(examples[list_column_idx]).combine_chunks(), ids)
    lengths = lengths.sort("descending", by=0)

    segment_tree = _SegmentTree(seq_length)
    segment_tree.add(seq_length)  # the max, `seq_length` bin is always available
    # Maps a remaining capacity to the bins that currently have exactly that much room left.
    space_to_bin = defaultdict(deque)

    # Bin is represented as a dict (of example ids and sum of their lengths) to allow in-place updates
    bins: list[dict] = []
    for length, idx in zip(lengths.field(0).to_numpy(), lengths.field(1).to_numpy()):
        # NOTE(review): `search` is expected to return an available capacity that can hold `length` -- see
        # `_SegmentTree` for the exact contract.
        space = segment_tree.search(length)

        if space < seq_length:
            # Reuse an existing bin that has exactly `space` room left.
            current_bin = space_to_bin[space].popleft()
        else:
            # Only a full-size slot is returned: open a new bin.
            current_bin = {"ids": [], "length": 0}
            bins.append(current_bin)

        current_bin["ids"].append(idx)
        current_bin["length"] += length
        # Drop this capacity from the tree once no partially filled bin offers it anymore. The full-size
        # capacity is never removed here, so a new bin can always be opened.
        if space < seq_length and not space_to_bin[space]:
            segment_tree.remove(space)

        # Register the bin under its new remaining capacity; a full bin (0 left) is not offered again.
        space = space - length
        space_to_bin[space].append(current_bin)
        if space > 0:
            segment_tree.add(space)

    # Reorder rows so that members of the same bin are contiguous, then derive the bin boundaries.
    examples = pc.take(examples, [id_ for packed_bin in bins for id_ in packed_bin["ids"]])
    offsets = np.array([0] + [packed_bin["length"] for packed_bin in bins])
    offsets = np.cumsum(offsets)

    columns = []
    for column in examples.columns:
        assert len(column.chunks) == 1  # `pc.take` returns a ChunkedArray with a single chunk
        column = column.chunks[0]
        if pa.types.is_list(column.type) or pa.types.is_large_list(column.type):
            # Rebuild the list array over the same flat values with one offset per bin, which concatenates the
            # rows of a bin. The offsets dtype follows the array kind (list vs. large list).
            dtype = column.offsets.type.to_pandas_dtype()
            column = type(column).from_arrays(offsets.astype(dtype), column.values)
        columns.append(column)
    return pa.Table.from_arrays(columns, names=examples.column_names)