in src/nanotron/trainer.py:
def _update_dataloader_based_on_training_stages(self, dataloaders: Union[List[DataLoader], DataLoader]):
    from collections.abc import Generator

    if not hasattr(self.config, "data_stages") or self.config.data_stages is None:
        if self.current_dataloader is None:
            if isinstance(dataloaders, tuple):
                dataloader = dataloaders[0]
            else:
                dataloader = dataloaders
            self.current_dataloader = sanity_check_dataloader(
                dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
            )
            self.current_base_dl = dataloader
        return
    elif isinstance(dataloaders, Generator):
        # TODO(xrsrke): this is a hacky way to handle DoReMi's dataloader
        # remove this in the next PR
        self.current_dataloader = dataloaders
        return
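
    # NOTE: from here on, `dataloaders` is expected to be a mapping from stage name
    # to a DataLoader (or to a callable that lazily builds one), as used below via
    # `dataloaders[stage.name]`.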
    assert len(dataloaders) > 0, "No dataloaders provided"

    def clear_dataloader_from_memory(dataloader: DataLoader, stage_name: str):
        import gc

        log_rank(
            f"[Training Stage: {stage_name}] Clearing the previous training stage's dataloader and datasets from memory",
            logger=logger,
            level=logging.INFO,
        )
        self.current_base_dl = None

        # NOTE: Clear dataloader from memory
        del dataloader.dataset
        del dataloader.sampler
        del dataloader.batch_sampler
        gc.collect()
        dataloader = None
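
    # NOTE: the helper above is invoked at a stage boundary (see the loop below) so
    # that the previous stage's dataset, sampler, and batch sampler are released
    # before the next stage's dataloader is (lazily) built.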
    def find_stage_idx_to_resume():
        reversed_data_stages = sorted(self.config.data_stages, key=lambda x: x.start_training_step, reverse=True)
        for idx, stage in enumerate(reversed_data_stages):
            if self.iteration_step >= stage.start_training_step:
                return len(self.config.data_stages) - idx - 1
        return None
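
    # Example (illustrative): with stages starting at steps [1, 1000, 5000] and
    # self.iteration_step == 1200, the reversed scan matches start step 1000 first,
    # so this returns index 1, i.e. the second stage is the one to resume from.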
    stage_idx_to_resume = find_stage_idx_to_resume()

    # NOTE: initialize so the check after the loop is well-defined even when no
    # stage switch happens at this iteration
    dataloader = None

    for stage_idx, stage in enumerate(self.config.data_stages):
        if stage_idx < self.metadata.last_stage_idx:
            continue

        stage = cast(DatasetStageArgs, stage)

        is_resume_from_training = self.current_dataloader is None and stage_idx_to_resume == stage_idx
        if (stage.start_training_step == self.iteration_step) or is_resume_from_training:
            if self.current_dataloader is not None:
                prev_stage_name = self.config.data_stages[stage_idx - 1].name
                prev_dataloader = dataloaders[prev_stage_name]
                if isinstance(prev_dataloader, DataLoader):
                    # NOTE: we don't need to clear the dummy data generator from memory
                    clear_dataloader_from_memory(prev_dataloader, stage_name=stage.name)

            self.metadata.last_stage_idx = stage_idx

            if is_resume_from_training:
                remaining_train_steps = compute_remain_train_steps_of_a_data_stage_from_ckp(
                    stage, self.config, self.metadata
                )
                (
                    consumed_train_steps,
                    consumed_tokens_per_dataset_folder,
                ) = get_consumed_train_samples_of_a_data_stage_from_ckp(stage, self.metadata)
                log_rank(
                    f"Resuming training from stage {stage.name}, it has trained for {consumed_train_steps} samples and has {remaining_train_steps} remaining train steps"
                    f"\nConsumed tokens per dataset folder: {pformat(consumed_tokens_per_dataset_folder)}",
                    logger=logger,
                    level=logging.INFO,
                    rank=0,
                )

            dataloader = dataloaders[stage.name]
            # NOTE: if a dataloader is lazily initialized, we need to call it to initialize it
            dataloader = dataloader() if callable(dataloader) else dataloader
            break
    if dataloader is not None:
        self.current_dataloader = sanity_check_dataloader(
            dataloader=dataloader, parallel_context=self.parallel_context, config=self.config
        )
        self.current_base_dl = dataloader
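
For context, here is a minimal, self-contained sketch of the stage-selection logic that find_stage_idx_to_resume implements, using a simplified stand-in for DatasetStageArgs with only the two fields the loop reads (name, start_training_step). The names ToyStage and pick_stage_idx are illustrative and not part of nanotron.

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToyStage:
    # stand-in for DatasetStageArgs; only the fields read by the stage-switching loop
    name: str
    start_training_step: int


def pick_stage_idx(stages: List[ToyStage], iteration_step: int) -> Optional[int]:
    # Same scan as find_stage_idx_to_resume: walk the stages from the latest start
    # step downwards and return the index of the first stage that has already started.
    # Like the original formula, this assumes `stages` is listed in ascending
    # start_training_step order.
    reversed_stages = sorted(stages, key=lambda s: s.start_training_step, reverse=True)
    for idx, stage in enumerate(reversed_stages):
        if iteration_step >= stage.start_training_step:
            return len(stages) - idx - 1
    return None


stages = [ToyStage("warmup", 1), ToyStage("annealing", 1000), ToyStage("long_context", 5000)]
assert pick_stage_idx(stages, 1) == 0      # still inside the first stage
assert pick_stage_idx(stages, 1200) == 1   # checkpoint resumed mid-way through the second stage
assert pick_stage_idx(stages, 7000) == 2   # last stage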