def pack_tensor()

in trl/models/activation_offloading.py [0:0]


        def pack_tensor(activation: torch.Tensor) -> int:
            # activations are passed in during the forward pass - from here we take over and return a unique id
            if self.is_first_forward_call:
                if len(self.tracker) != 0:
                    raise ValueError("Backward pass should have cleared tracker of all tensors")

                # set training phase trackers
                self.is_first_forward_call = False
                self.is_first_backward_call = True

            # query for basic tensor info
            num_bytes = get_num_bytes_tensor(activation)
            tensor_id = get_tensor_id()

            # only offload hefty bois if they're activations on CUDA/XPU (our heuristic
            # for that is to check that they're not params or buffers)!
            if (
                activation.device.type in ["cuda", "xpu"]
                and num_bytes >= self.min_tensor_size_bytes
                and (
                    not isinstance(activation, torch.nn.Parameter)
                    and not (hasattr(torch.nn, "Buffer") and isinstance(activation, torch.nn.Buffer))
                )
            ):
                if self.use_streams:
                    # First, sync back and dereference previously offloaded tensors,
                    # whose offloading should have completed sufficiently long ago.
                    for id in list(self.fwd_stash.keys()):
                        if id <= tensor_id - self.max_fwd_stash_size:
                            _, ev = self.fwd_stash[id]
                            self.s0.wait_event(ev)
                            del self.fwd_stash[id]
                        else:
                            break

                    # Sync in, offload, and add an event to sync back later
                    self.s1.wait_stream(self.s0)

                stream = self.s1 if self.use_streams else self.s0
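                # on XPU the stream object is entered directly; on CUDA it is wrapped with torch.cuda.stream()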
                with stream if self.accelerator_type == "xpu" else torch.cuda.stream(stream):
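                    # allocate an (optionally pinned) CPU buffer and start an asynchronous device-to-host copy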
                    cpu_tensor = torch.empty_like(activation, pin_memory=self.use_pin_memory, device="cpu")
                    cpu_tensor.copy_(activation, non_blocking=True)
                    self.tracker[tensor_id] = (
                        cpu_tensor,
                        True,  # True = modified (offloaded to a CPU copy)
                    )

                if self.use_streams:
                    event = self.s1.record_event()

                    # Stash to keep the activation alive until s1 is done
                    self.fwd_stash[tensor_id] = (activation, event)
            else:
                self.tracker[tensor_id] = (
                    activation,
                    False,
                )  # False = not modified, tensor is kept as-is

            return tensor_id
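
For context, a pack hook like this is registered together with a matching unpack hook through torch.autograd.graph.saved_tensors_hooks, so every tensor saved for the backward pass is routed through pack_tensor and later restored by its id. The sketch below is a minimal, self-contained illustration of that pairing, not TRL's actual implementation: the MinimalOffloader class, its pack/unpack methods, and the 1 MiB threshold are hypothetical simplifications (no streams, pinned memory, or fwd_stash bookkeeping), and the usage example assumes a CUDA device is available.

    import torch

    class MinimalOffloader:
        # Hypothetical, simplified stand-in for the real offloading context:
        # large CUDA activations are copied to CPU on pack and moved back on unpack.
        def __init__(self, min_tensor_size_bytes: int = 1024 * 1024):
            self.min_tensor_size_bytes = min_tensor_size_bytes
            self.tracker = {}
            self._next_id = 0

        def pack(self, activation: torch.Tensor) -> int:
            tensor_id = self._next_id
            self._next_id += 1
            num_bytes = activation.element_size() * activation.nelement()
            if activation.device.type == "cuda" and num_bytes >= self.min_tensor_size_bytes:
                # offload: keep a CPU copy and remember that the tensor was moved
                self.tracker[tensor_id] = (activation.to("cpu"), True)
            else:
                # small or non-accelerator tensors are stored as-is
                self.tracker[tensor_id] = (activation, False)
            return tensor_id

        def unpack(self, tensor_id: int) -> torch.Tensor:
            tensor, was_offloaded = self.tracker.pop(tensor_id)
            # bring offloaded activations back to the accelerator for the backward pass
            return tensor.to("cuda") if was_offloaded else tensor

    offloader = MinimalOffloader()
    model = torch.nn.Linear(4096, 4096).cuda()
    x = torch.randn(8, 4096, device="cuda", requires_grad=True)

    # route every tensor saved for backward through the pack/unpack hooks
    with torch.autograd.graph.saved_tensors_hooks(offloader.pack, offloader.unpack):
        loss = model(x).sum()
    loss.backward()

Unlike this sketch, the function above skips parameters and buffers, applies a minimum-size threshold per the configured min_tensor_size_bytes, and overlaps the device-to-host copies with compute using dedicated streams and pinned memory.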