in src/nanotron/data/clm_collator.py
def __call__(self, examples: List[Dict[str, List[np.ndarray]]]) -> Dict[str, Union[torch.Tensor, TensorPointer]]:
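"""Collate tokenized examples into the per-rank inputs of a causal LM training step.

Pipeline ranks that are neither the input nor the output rank receive only
TensorPointers. The input rank gets `input_ids` and `position_ids`, the output
rank gets `label_ids` and `label_mask`; with context parallelism every CP rank
keeps only its local slice of the sequence.
"""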
# Process the case when current rank doesn't require data
current_pp_rank = dist.get_rank(self.parallel_context.pp_pg)
if current_pp_rank not in [self.input_pp_rank, self.output_pp_rank]:
assert all(len(example) == 0 for example in examples)
return {
"input_ids": TensorPointer(group_rank=self.input_pp_rank),
"positions": TensorPointer(group_rank=self.input_pp_rank),
"label_ids": TensorPointer(group_rank=self.output_pp_rank),
"label_mask": TensorPointer(group_rank=self.output_pp_rank),
}
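# Worked example of one packed sample: a new document starts wherever position_ids resets to 0.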
# input_ids[0,:20]
# array([ 198, 50, 30, 12532, 3589, 198, 51, 30, 30618,
# 198, 52, 30, 8279, 11274, 198, 21350, 42, 340,
# 0, 1780])
# position_ids[0,:20]
# array([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
# 17, 18, 0])
# result["label_ids"][0,:20]
# array([ 50, 30, 12532, 3589, 198, 51, 30, 30618, 198,
# 52, 30, 8279, 11274, 198, 21350, 42, 340, 0,
# 1780, 314])
# -> the label at the eos position (token 0) is 1780, the first token of the next document, so it must be masked
# result["label_mask"][0,:20]
# array([ True, True, True, True, True, True, True, True, True,
# True, True, True, True, True, True, True, True, True,
# False, True])
# each document ends with eos_token (0); label_mask is True everywhere except at the eos position,
# whose label would be the first token of the following document
# Stack input_ids
input_ids = np.vstack([examples[i]["input_ids"] for i in range(len(examples))]) # (b, s)
batch_size, expanded_input_length = input_ids.shape
result: Dict[str, Union[np.ndarray, TensorPointer]] = {}
# Initialize all fields as TensorPointers
result["input_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["position_ids"] = TensorPointer(group_rank=self.input_pp_rank)
result["label_ids"] = TensorPointer(group_rank=self.output_pp_rank)
result["label_mask"] = TensorPointer(group_rank=self.output_pp_rank)
assert expanded_input_length == self.sequence_length + 1, (
f"Samples should be of length {self.sequence_length + 1} (seq_len+1), but got {expanded_input_length}"
)
# Context Parallelism: each CP rank keeps only its local slice of the sequence
cp_rank, cp_size = dist.get_rank(self.parallel_context.cp_pg), self.parallel_context.context_parallel_size
local_slice = slice(
cp_rank * self.sequence_length // cp_size, (cp_rank + 1) * self.sequence_length // cp_size
)
# Process inputs
if current_pp_rank == self.input_pp_rank:
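# Drop the final token: it is only ever used as the last label, never as an input.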
result["input_ids"] = input_ids[:, :-1]
if "positions" in examples[0] and self.use_doc_masking:
# Use provided position_ids if available
position_ids = np.vstack([examples[i]["positions"] for i in range(len(examples))])
# Simply drop the last position ID for each example
result["positions"] = position_ids[:, :-1]
else:
# Default: sequential position ids
result["positions"] = np.arange(self.sequence_length)[None, :].repeat(batch_size, axis=0)
# Context Parallelism: each CP rank gets a slice of the input_ids and position_ids
result["input_ids"] = result["input_ids"][:, local_slice] # (b, s/cp_size)
if not self.cp_return_global_position_ids:
result["positions"] = result["positions"][:, local_slice] # (b, s/cp_size)
result["position_ids"] = result.pop("positions")
# Process labels
if current_pp_rank == self.output_pp_rank:
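# Labels are the inputs shifted left by one token (next-token prediction targets).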
result["label_ids"] = input_ids[:, 1:]
# Create label mask based on position_ids
if "positions" in examples[0] and self.use_doc_masking:
# Align position_ids with the labels: label_ids are input_ids[:, 1:], so take position_ids[:, 1:] as well
position_ids = np.vstack([examples[i]["positions"] for i in range(len(examples))])
position_ids = position_ids[:, 1:]
# Start with everything unmasked, then mask positions whose label starts a new document (aligned position_id == 0)
result["label_mask"] = np.ones((batch_size, self.sequence_length), dtype=np.bool_)
# Find where position_ids is 0
zeros = position_ids == 0
# A zero in the aligned position_ids means the label at that index is a document's first token -> mask it
result["label_mask"] &= ~zeros
else:
# Default: all tokens are used for loss
result["label_mask"] = np.ones((batch_size, self.sequence_length), dtype=np.bool_)
# Context Parallelism: each CP rank gets a slice of the label_ids and label_mask
result["label_ids"] = result["label_ids"][:, local_slice] # (b, s/cp_size)
result["label_mask"] = result["label_mask"][:, local_slice] # (b, s/cp_size)
# Validate shapes (values here are numpy arrays or TensorPointers)
if (
isinstance(result["input_ids"], (np.ndarray, torch.Tensor))
and result["input_ids"].shape[-1] != self.sequence_length // cp_size
):
raise ValueError(
f"`input_ids` are incorrectly preprocessed. Length is {result['input_ids'].shape[-1]}, but should be"
f" {self.sequence_length // cp_size}."
)
if (
isinstance(result["label_ids"], (np.ndarray, torch.Tensor))
and isinstance(result["input_ids"], (np.ndarray, torch.Tensor))
and result["label_ids"].shape[-1] != result["input_ids"].shape[-1]
):
raise ValueError(
f"`label_ids` are incorrectly preprocessed. Length is {result['label_ids'].shape[-1]}, but should be"
f" {result['input_ids'].shape[-1]}."
)
# # Cast np.array to torch.Tensor
# result = {
# k: v if isinstance(v, TensorPointer) else torch.from_numpy(v).contiguous() for k, v in result.items()
# }
# # assert contiguous
# for k, v in result.items():
# if not isinstance(v, TensorPointer):
# assert v.is_contiguous(), f"{k} is not contiguous"
# assert not v.is_cuda, f"{k} is in cuda. Bad for pinning memory"
return result