def offload()

in chatlearn/models/torch_module.py [0:0]


    def offload(self, to_offload_weights=None, to_free_grad_buffers=None, to_offload_main_weights=None, to_offload_optimizer_states=None):
        """Offload weights / optimizer states and free grad buffers to release device memory.

        Does nothing unless this module is colocated (``self.is_colocate``) or
        ``self.module_args.force_free_memory`` is set.

        Each ``to_*`` argument, when not None, overrides the corresponding
        ``self.module_args`` setting (resolved through ``self._get_if_not_none``;
        presumably "use the argument unless it is None, else the default").
        ``to_offload_main_weights`` falls back to ``module_args.offload_weights``,
        i.e. the same setting as ``to_offload_weights``.

        Args:
            to_offload_weights: whether to call ``self.offload_weights()``.
            to_free_grad_buffers: whether to call ``self.free_grad_buffers()``
                (only when ``self.trainable``).
            to_offload_main_weights: whether to call ``self.offload_main_weights()``
                (only when ``self.trainable``).
            to_offload_optimizer_states: whether to call
                ``self.offload_optimizer_states()`` (only when ``self.trainable``).

        Steps run in this order: free_grad_buffers -> offload_main_weights ->
        offload_optimizer_states -> offload_weights. The first call of
        ``offload_weights`` / ``offload_main_weights`` has a higher peak memory,
        so grad buffers are freed first to make room.

        NOTE(review): the original rationale said ``offload_weights`` should run
        before ``offload_main_weights`` "to make more space" for it, but here
        ``offload_weights`` runs last. Confirm which order is intended.

        The whole phase is bracketed by ``torch.distributed.barrier()`` calls, so
        all ranks of the default process group are expected to call this together.
        """
        if not self.is_colocate and not self.module_args.force_free_memory:
            return

        cfg = self.module_args
        weights_on = self._get_if_not_none(to_offload_weights, cfg.offload_weights)
        main_weights_on = self._get_if_not_none(to_offload_main_weights, cfg.offload_weights)
        grad_buffers_on = self._get_if_not_none(to_free_grad_buffers, cfg.free_grad_buffers)
        optim_states_on = self._get_if_not_none(to_offload_optimizer_states, cfg.offload_optimizer_states)

        # Nothing requested: skip the syncs, barriers and timer entirely.
        if not any((grad_buffers_on, weights_on, optim_states_on, main_weights_on)):
            return

        log_rank_0(get_full_proc_memory_info('Before offload'), self._logger)
        torch.cuda.synchronize()
        free_timer = self.timers(f'{self.name}_free_memory')
        # The timer may already be running; only start it if it is not.
        if not free_timer.started_:
            free_timer.start()
        # Enter the offload phase on all ranks together.
        torch.distributed.barrier()

        if self.trainable:
            if grad_buffers_on:
                self.free_grad_buffers()
            if main_weights_on:
                self.offload_main_weights()
            if optim_states_on:
                self.offload_optimizer_states()
        if weights_on:
            self.offload_weights()

        # Leave the offload phase on all ranks together.
        torch.distributed.barrier()
        torch.cuda.synchronize()
        # Collect unreachable Python objects first so empty_cache() can hand
        # their now-unused cached blocks back.
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats()
        free_timer.stop()
        log_rank_0(get_full_proc_memory_info('After offload'), self._logger)