in chatlearn/models/torch_module.py [0:0]
def offload(self, to_offload_weights=None, to_free_grad_buffers=None, to_offload_main_weights=None, to_offload_optimizer_states=None):
    # The first call to `offload_weights` and `offload_main_weights` has a higher peak memory.
    # `free_grad_buffers` is therefore called first to free memory, and `offload_weights` is
    # deferred until last, leaving more headroom for `offload_main_weights`.
    if not (self.is_colocate or self.module_args.force_free_memory):
        return
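    # An explicit argument takes precedence; None falls back to the corresponding
    # flag in module_args. Note that `offload_main_weights` defaults to the same
    # `offload_weights` config flag as `offload_weights` itself.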
    to_offload_weights = self._get_if_not_none(to_offload_weights, self.module_args.offload_weights)
    to_offload_main_weights = self._get_if_not_none(to_offload_main_weights, self.module_args.offload_weights)
    to_free_grad_buffers = self._get_if_not_none(to_free_grad_buffers, self.module_args.free_grad_buffers)
    to_offload_optimizer_states = self._get_if_not_none(to_offload_optimizer_states, self.module_args.offload_optimizer_states)
    if to_free_grad_buffers or to_offload_weights or to_offload_optimizer_states or to_offload_main_weights:
        log_rank_0(get_full_proc_memory_info('Before offload'), self._logger)
        torch.cuda.synchronize()
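        # Guard against double-starting: the timer may still be running from an
        # earlier offload call.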
        timer = self.timers(f'{self.name}_free_memory')
        if not timer.started_:
            timer.start()
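        # Synchronize all ranks before offloading; only trainable modules hold grad
        # buffers, main weights, and optimizer states to release.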
        torch.distributed.barrier()
        if self.trainable:
            if to_free_grad_buffers:
                self.free_grad_buffers()
            if to_offload_main_weights:
                self.offload_main_weights()
            if to_offload_optimizer_states:
                self.offload_optimizer_states()
            if to_offload_weights:
                self.offload_weights()
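        # Wait for every rank to finish, then drop dangling Python references,
        # release cached CUDA blocks, and reset the peak-memory counter so the
        # next phase is measured from a clean baseline.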
        torch.distributed.barrier()
        torch.cuda.synchronize()
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        timer.stop()
        log_rank_0(get_full_proc_memory_info('After offload'), self._logger)
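`_get_if_not_none` is not shown in this excerpt. Judging from its use above, it prefers an explicit argument over the configured default; a minimal sketch of that helper (an assumption, not the confirmed implementation) could be:

def _get_if_not_none(self, value, default):
    # Hypothetical sketch: fall back to the config default only when the caller
    # passed None. Testing `is not None` rather than truthiness lets a caller
    # pass an explicit False to disable a single offload step.
    return value if value is not None else default

With these defaults, a colocated trainer can simply call `self.offload()` with no arguments after a training step and let the `module_args` flags decide which of the four release steps actually run.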