in trl/models/activation_offloading.py
def unpack_tensor_with_streams(unpack_tensor_id: int) -> torch.Tensor:
    # backward pass - we are called with the tensor_id, which
    # we will use to retrieve the saved/offloaded tensor
    if self.is_first_backward_call:
        self.curr_graph_id = torch._C._current_graph_task_id()
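        # (This id is compared against in the `modified` branch below so that
        # stash rotation only applies to nodes of this same backward graph and
        # not to any nested/reentrant backward call.)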

        def wait_and_del_remaining_references() -> None:
            for id in list(self.bwd_tensor_stash.keys()):
                event = self.bwd_ev_stash[id]
                self.s1.wait_event(event)
                del self.bwd_tensor_stash[id]

        # Register a callback to the end of autograd to clean everything up
        torch.autograd.variable.Variable._execution_engine.queue_callback(wait_and_del_remaining_references)
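        # (queue_callback runs the callback once the current backward pass has
        # finished executing all of its nodes, so whatever is still stashed at
        # that point can safely be waited on and freed.)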

        if self.is_first_forward_pass:
            self.is_first_forward_pass = False
            if self.use_pin_memory:
                verify_sufficient_virtual_memory()

        self.is_first_backward_call = False
        self.is_first_forward_call = True

    if unpack_tensor_id not in self.tracker:
        raise ValueError(f"untracked tensor with id {unpack_tensor_id}")

    maybe_accelerator_tensor, modified = self.tracker[unpack_tensor_id]
    if modified:
        # Get data on the current autograd node
        graph_id = torch._C._current_graph_task_id()
        node = torch._C._current_autograd_node()
        prev_node_ids = []

        # If we're on a new node, mark prev node's tensors to be freed later
        if graph_id == self.curr_graph_id and self.curr_autograd_node != node:
            self.curr_autograd_node = node
            prev_node_ids = list(self.bwd_tensor_stash.keys())
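            # (The actual freeing happens in the post-hook below, once the
            # events recorded for those tensors have completed.)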

        brought_back_from_cpu = True
        if unpack_tensor_id in self.fwd_stash:
            maybe_accelerator_tensor = self.fwd_stash[unpack_tensor_id][0]
            brought_back_from_cpu = False
        else:
            # Kick off the process to bring tensors back
            with self.s1 if self.accelerator_type == "xpu" else torch.cuda.stream(self.s1):
                accelerator_tensor = maybe_accelerator_tensor.to(self.accelerator_type, non_blocking=True)
                maybe_accelerator_tensor = accelerator_tensor

            # Tell comp stream to wait for the info to be loaded before executing
            self.s0.wait_stream(self.s1)
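            # (Standard two-stream prefetch: the host-to-device copy is issued
            # on the side stream s1 so it can overlap with compute still running
            # on s0, and s0 only stalls if it reaches this tensor before the
            # copy has finished.)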

        # Stash the tensor to keep memory alive until compute stream is complete
        self.bwd_tensor_stash[unpack_tensor_id] = maybe_accelerator_tensor

        # Note: [Track views of the unpacked]
        # Why do we get the use count of the unpacked tensor here? We want an
        # initial count to compare to later, during the post-hook of the
        # backward node, when we need to decide whether we're allowed to free
        # the tensor yet. In what obscure cases must we delay freeing the
        # tensor (and thus call record_stream)?
        # 1. Any of the outputs of the backward node is a view of the unpacked
        #    tensor.
        # 2. In the case that this unpacked tensor will be used in a
        #    checkpointed region, if one of the recomputed saved tensors ends
        #    up as a view of the unpacked tensor.
        # 3. The user abuses the system somehow and manually relies on the
        #    unpacked tensor to exist after the backward node has executed.
        storage_refcount = torch._C._storage_Use_Count(maybe_accelerator_tensor.untyped_storage()._cdata)
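        # (If the post-hook later sees a strictly larger use count than the one
        # recorded here, something new (a view, a recomputed saved tensor, or a
        # user-held reference) is keeping the storage alive.)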

        def hook(outputs, inputs):
            # Create events for the current node's inputs/outputs if they were streamed in
            if brought_back_from_cpu:
                # See Note: [Track views of the unpacked]
                # IF any of the outputs is a view of the tensor, OR if a view of
                # the tensor has been saved as a part of checkpoint's recompute
                # process, OR the user has otherwise kept a reference to the
                # unpacked tensor, THEN the tensor might be used later and we
                # cannot presume to delete it after only the current node is
                # done! So we use our frenemy, record_stream, to ensure the
                # tensor stays untouched until it's done being used in the
                # compute stream (s0 here). The downside is that we introduce
                # non-deterministic (thus higher) memory usage, but this case
                # should not happen often.
                unpacked_tensor = self.bwd_tensor_stash[unpack_tensor_id]
                if torch._C._storage_Use_Count(unpacked_tensor.untyped_storage()._cdata) > storage_refcount:
                    unpacked_tensor.record_stream(self.s0)
                    del self.bwd_tensor_stash[unpack_tensor_id]
                else:
                    event = self.s0.record_event()
                    self.bwd_ev_stash[unpack_tensor_id] = event
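                    # (Common case: remember where the compute stream is right
                    # now; the stash entry is freed later, once a subsequent
                    # node or the end-of-backward callback waits on this event.)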

            # if there are still things in the fwd_stash, get rid of them as we're in bwd now
            for id in list(self.fwd_stash.keys()):
                _, ev = self.fwd_stash[id]
                self.s0.wait_event(ev)
                del self.fwd_stash[id]

            # wait on prev node's events and del those
            for id in prev_node_ids:
                event = self.bwd_ev_stash[id]
                self.s1.wait_event(event)
                del self.bwd_tensor_stash[id]

            return outputs

        node.register_hook(hook)

    # clear tensor from tracking
    del self.tracker[unpack_tensor_id]
    return maybe_accelerator_tensor
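

# --- Illustrative sketch (not part of the TRL source) ------------------------
# A minimal, self-contained demo of the torch.autograd.graph.saved_tensors_hooks
# protocol that unpack_tensor_with_streams implements one half of: pack swaps
# each saved activation for a lightweight id, and unpack resolves that id back
# to a tensor during backward. The real implementation above additionally
# offloads to CPU and overlaps the round-trip on a side stream; all names below
# are illustrative, not the TRL API.
def _demo_saved_tensors_hooks() -> None:
    import itertools

    import torch

    ids = itertools.count()
    stash: dict[int, torch.Tensor] = {}

    def pack(t: torch.Tensor) -> int:
        tid = next(ids)
        stash[tid] = t  # a real offloader would copy to pinned CPU memory here
        return tid  # autograd keeps only this id

    def unpack(tid: int) -> torch.Tensor:
        return stash.pop(tid)  # a real offloader would stream the tensor back in

    x = torch.randn(4, 4, requires_grad=True)
    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        y = (x * x).sum()
    y.backward()  # unpack is called here for every tensor saved in the region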