# Excerpt from trl/models/activation_offloading.py
def pack_tensor(activation: torch.Tensor) -> int:
    """Autograd save-hook: take ownership of ``activation``, return a unique id.

    Called once per saved activation during the forward pass. Large CUDA/XPU
    activations are copied to (optionally pinned) CPU memory — on the side
    stream ``self.s1`` when ``self.use_streams`` is set — while small tensors,
    ``nn.Parameter``s, and ``nn.Buffer``s are kept on-device untouched. The
    matching unpack hook later retrieves the tensor from ``self.tracker`` by
    the returned id.

    Args:
        activation: Tensor saved by autograd for the backward pass.

    Returns:
        Unique integer id keying this tensor in ``self.tracker``.

    Raises:
        ValueError: If ``self.tracker`` still holds tensors from a previous
            forward pass (i.e. the backward pass did not drain it).
    """
    # First pack call of a new forward pass: the previous backward must have
    # cleared the tracker, otherwise ids would leak across iterations.
    if self.is_first_forward_call:
        if len(self.tracker) != 0:
            raise ValueError("Backward pass should have cleared tracker of all tensors")
        # set training phase trackers
        self.is_first_forward_call = False
        self.is_first_backward_call = True

    # query for basic tensor info
    num_bytes = get_num_bytes_tensor(activation)
    tensor_id = get_tensor_id()

    # Only offload large accelerator-resident tensors that are genuine
    # activations (heuristic: anything that is not a Parameter or Buffer).
    if (
        activation.device.type in ["cuda", "xpu"]
        and num_bytes >= self.min_tensor_size_bytes
        and (
            not isinstance(activation, torch.nn.Parameter)
            and not (hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer))
        )
    ):
        if self.use_streams:
            # Sync back and dereference previously offloaded tensors whose
            # device->host copies should have finished long ago. (Loop
            # variable renamed from `id` to avoid shadowing the builtin.)
            for stashed_id in list(self.fwd_stash.keys()):
                if stashed_id <= tensor_id - self.max_fwd_stash_size:
                    _, ev = self.fwd_stash[stashed_id]
                    self.s0.wait_event(ev)
                    del self.fwd_stash[stashed_id]
                else:
                    # Stash keys are assigned in increasing order, so no
                    # later entry can be old enough to evict either.
                    break

            # Sync in, offload, and add an event to sync back later
            self.s1.wait_stream(self.s0)

        stream = self.s1 if self.use_streams else self.s0
        # NOTE(review): for XPU the stream object is used directly as a
        # context manager, while CUDA goes through torch.cuda.stream(...) —
        # presumably torch.xpu streams support `with`; confirm against the
        # torch version this targets.
        with stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream):
            cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu")
            cpu_tensor.copy_(activation, non_blocking=True)
            self.tracker[tensor_id] = (
                cpu_tensor,
                True,  # True = (in future) modified
            )

        if self.use_streams:
            event = self.s1.record_event()
            # Stash to keep activation alive til s1 is done
            self.fwd_stash[tensor_id] = (activation, event)
    else:
        # Small tensors, params, and buffers stay on device — no copy made.
        self.tracker[tensor_id] = (
            activation,
            False,
        )  # False = not modified, tensor is as is
    return tensor_id