in chatlearn/models/torch_module.py [0:0]
def onload(self, to_onload_weights=None, to_build_grad_buffers=None, to_onload_main_weights=None, to_onload_optimizer_states=None):
    # Onloading is only relevant when this module shares GPUs with others
    # (colocation) or memory freeing is forced via configuration.
    if not (self.is_colocate or self.module_args.force_free_memory):
        return
    # Explicit per-call arguments take precedence; otherwise fall back to
    # the corresponding offload/free flags in the module config.
    to_onload_weights = self._get_if_not_none(to_onload_weights, self.module_args.offload_weights)
    to_build_grad_buffers = self._get_if_not_none(to_build_grad_buffers, self.module_args.free_grad_buffers)
    to_onload_main_weights = self._get_if_not_none(to_onload_main_weights, self.module_args.offload_weights)
    to_onload_optimizer_states = self._get_if_not_none(to_onload_optimizer_states, self.module_args.offload_optimizer_states)
    if to_onload_weights or to_build_grad_buffers or to_onload_main_weights or to_onload_optimizer_states:
        log_rank_0(get_full_proc_memory_info('Before onload'), self._logger)
        torch.cuda.synchronize()

        timer = self.timers(f'{self.name}_free_memory')
        if not timer.started_:
            timer.start()

        # Keep all ranks in lockstep so no rank proceeds against a peer
        # whose state is only partially restored.
        torch.distributed.barrier()
        if to_onload_weights:
            self.onload_weights()
        # Grad buffers, main (fp32) weights, and optimizer states only
        # exist on trainable modules.
        if self.trainable:
            if to_build_grad_buffers:
                self.build_grad_buffers()
            if to_onload_main_weights:
                self.onload_main_weights()
            if to_onload_optimizer_states:
                self.onload_optimizer_states()
        torch.distributed.barrier()

        torch.cuda.synchronize()
        # Release cached allocator blocks and collect host-side garbage so
        # subsequent allocations see an accurate view of free GPU memory.
        torch.cuda.empty_cache()
        gc.collect()

        timer.stop()
        log_rank_0(get_full_proc_memory_info('After onload'), self._logger)
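
For context, a minimal sketch of the defaulting helper used above. It restates the behavior implied by the call sites (an explicit argument wins over the configured default); the body shown here is an assumption inferred from usage, not the verified ChatLearn implementation:

def _get_if_not_none(self, value, default):
    # Assumed behavior, inferred from the call sites in onload():
    # prefer the explicit per-call flag, else use the config default.
    return value if value is not None else default

Under that assumption, calling module.onload() with no arguments follows the module configuration, while a call such as module.onload(to_onload_weights=True, to_build_grad_buffers=False) (hypothetical usage) overrides individual flags for a single invocation.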