in trl/data_utils.py [0:0]
def _pack_ffd(examples: pa.Table, seq_length: int) -> pa.Table:
    """Pack sequences in a pyarrow Table using the First Fit Decreasing strategy.

    Sequences are truncated to `seq_length`, sorted by length (descending), and each
    one is greedily placed into the bin with the smallest remaining space that still
    fits it (found via a segment tree over available spaces); a new bin is opened when
    none fits. Packed rows are then assembled by rewriting the list offsets of each
    list column, so the underlying values buffers are never copied.

    Args:
        examples: Table whose list columns (including "input_ids") hold the sequences
            to pack. A "position_ids" column (0..len-1 per sequence) is added.
        seq_length: Maximum packed length; also the capacity of each bin.

    Returns:
        A new `pa.Table` where each row is one packed bin.
    """
    # Add position_ids to the examples, restarting from 0 for every sequence.
    input_ids = examples["input_ids"]
    position_ids_python = [list(range(len(sequence))) for sequence in input_ids.to_pylist()]
    position_ids_array = pa.array(position_ids_python, type=examples["input_ids"].type)
    examples = examples.append_column("position_ids", position_ids_array)

    # Truncate every list column to `seq_length`, and remember the first list column:
    # its per-row lengths drive the packing below.
    columns = []
    list_column_idx = None
    for idx, column in enumerate(examples.columns):
        # NOTE: use the `pa` alias consistently (the original mixed `pyarrow.types`
        # and `pa.types`, which relies on the bare `pyarrow` name being bound).
        if pa.types.is_list(column.type) or pa.types.is_large_list(column.type):
            column = pc.list_slice(column, 0, seq_length)
            if list_column_idx is None:
                list_column_idx = idx
        columns.append(column)
    examples = pa.Table.from_arrays(columns, names=examples.column_names)

    ids = np.arange(len(examples))
    assert list_column_idx is not None
    # Pair each sequence length with its original row id, then sort by length
    # descending — the "Decreasing" part of FFD.
    lengths = pc.make_struct(pc.list_value_length(examples[list_column_idx]).combine_chunks(), ids)
    lengths = lengths.sort("descending", by=0)

    segment_tree = _SegmentTree(seq_length)
    segment_tree.add(seq_length)  # the max, `seq_length` bin is always available
    space_to_bin = defaultdict(deque)

    # Bin is represented as a dict (of example ids and sum of their lengths) to allow in-place updates
    bins: list[dict] = []
    for length, idx in zip(lengths.field(0).to_numpy(), lengths.field(1).to_numpy()):
        # Smallest available space that still fits this sequence ("First Fit").
        space = segment_tree.search(length)
        if space < seq_length:
            # A partially filled bin with exactly `space` left exists: reuse it.
            target_bin = space_to_bin[space].popleft()
        else:
            # Only the virtual full-capacity slot fits: open a fresh, empty bin.
            target_bin = {"ids": [], "length": 0}
            bins.append(target_bin)
        target_bin["ids"].append(idx)
        target_bin["length"] += length
        # Keep the segment tree in sync with the multiset of available spaces.
        if space < seq_length and not space_to_bin[space]:
            segment_tree.remove(space)
        space = space - length
        space_to_bin[space].append(target_bin)
        if space > 0:
            segment_tree.add(space)

    # Reorder rows so that members of the same bin are contiguous.
    examples = pc.take(examples, [id_ for bin_ in bins for id_ in bin_["ids"]])
    # New offsets merge all rows of a bin into a single packed row.
    offsets = np.array([0] + [bin_["length"] for bin_ in bins])
    offsets = np.cumsum(offsets)
    columns = []
    for column in examples.columns:
        assert len(column.chunks) == 1  # `pc.take` returns a ChunkedArray with a single chunk
        column = column.chunks[0]
        if pa.types.is_list(column.type) or pa.types.is_large_list(column.type):
            # Zero-copy repack: reuse the values buffer, swap in the new offsets
            # (cast to the column's native offset dtype: int32 vs int64 lists).
            dtype = column.offsets.type.to_pandas_dtype()
            column = type(column).from_arrays(offsets.astype(dtype), column.values)
        columns.append(column)
    return pa.Table.from_arrays(columns, names=examples.column_names)