def onload()

in chatlearn/models/torch_module.py


    def onload(self, to_onload_weights=None, to_build_grad_buffers=None, to_onload_main_weights=None, to_onload_optimizer_states=None):
        # Onloading only applies when this module is colocated with others on
        # the same GPUs, or when memory freeing is explicitly forced.
        if not (self.is_colocate or self.module_args.force_free_memory):
            return
        # An explicitly passed flag wins; otherwise fall back to the configured
        # offload/free settings in module_args.
        to_onload_weights = self._get_if_not_none(to_onload_weights, self.module_args.offload_weights)
        to_build_grad_buffers = self._get_if_not_none(to_build_grad_buffers, self.module_args.free_grad_buffers)
        to_onload_main_weights = self._get_if_not_none(to_onload_main_weights, self.module_args.offload_weights)
        to_onload_optimizer_states = self._get_if_not_none(to_onload_optimizer_states, self.module_args.offload_optimizer_states)
        if to_onload_weights or to_build_grad_buffers or to_onload_main_weights or to_onload_optimizer_states:
            log_rank_0(get_full_proc_memory_info('Before onload'), self._logger)
            torch.cuda.synchronize()
            timer = self.timers(f'{self.name}_free_memory')
            if not timer.started_:
                timer.start()
            # Barriers keep all ranks in lockstep while memory is moved back.
            torch.distributed.barrier()
            if to_onload_weights:
                self.onload_weights()
            # Grad buffers, main weights and optimizer states exist only for
            # trainable modules.
            if self.trainable:
                if to_build_grad_buffers:
                    self.build_grad_buffers()
                if to_onload_main_weights:
                    self.onload_main_weights()
                if to_onload_optimizer_states:
                    self.onload_optimizer_states()
            torch.distributed.barrier()
            torch.cuda.synchronize()
            # Release cached blocks and collect garbage so the logged memory
            # numbers reflect the post-onload state.
            torch.cuda.empty_cache()
            gc.collect()
            timer.stop()
            log_rank_0(get_full_proc_memory_info('After onload'), self._logger)
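
The flag resolution above goes through the `_get_if_not_none` helper: an explicit argument takes precedence, otherwise the module follows its configured defaults. Below is a minimal sketch of that fallback plus a hypothetical call site; the parameter names and the `policy_trainer` instance are illustrative assumptions, not taken from the source.

    def _get_if_not_none(self, to_do, default):
        # Sketch: prefer the caller's explicit choice; otherwise use the
        # default taken from module_args.
        return to_do if to_do is not None else default

    # Hypothetical usage: calling onload() with no arguments follows the
    # module_args settings, while explicit flags selectively onload state.
    policy_trainer.onload()
    policy_trainer.onload(to_onload_weights=True,
                          to_build_grad_buffers=False,
                          to_onload_main_weights=False,
                          to_onload_optimizer_states=False)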