in tzrec/datasets/utils.py [0:0]
def pin_memory(self) -> "Batch":
    """Return a copy of this batch with its tensors copied to pinned memory.

    Each container is rebuilt from pinned copies of its contents:

    * ``dense_features``: ``KeyedTensor`` is rebuilt by hand. The new one
      uses the pinned ``values()`` and reuses ``keys()``,
      ``length_per_key()`` and ``key_dim()`` from the original.
    * ``sequence_dense_features``: a new ``JaggedTensor`` is built from the
      pinned ``values()``. Its ``_weights``, ``_lengths`` and ``_offsets``
      are also pinned; any of them that is ``None`` stays ``None``.
    * ``sparse_features``, ``sequence_mulval_lengths``, ``labels`` and
      ``sample_weights``: ``pin_memory()`` is called on each value.
    * ``reserves`` and ``tile_size`` are passed through unchanged.

    Returns:
        A new ``Batch``. No attribute of ``self`` is reassigned.
    """
    # TODO(hongsheng.jhs): KeyedTensor do not have pin_memory()

    def _pin_optional(tensor):
        # Pin an optional component; a missing component stays None.
        return None if tensor is None else tensor.pin_memory()

    def _pin_all(mapping):
        # Pin every value of a name -> tensor-like mapping, keeping key order.
        return {name: item.pin_memory() for name, item in mapping.items()}

    pinned_dense = {
        name: KeyedTensor(
            keys=kt.keys(),
            length_per_key=kt.length_per_key(),
            values=kt.values().pin_memory(),
            key_dim=kt.key_dim(),
        )
        for name, kt in self.dense_features.items()
    }

    pinned_seq_dense = {
        name: JaggedTensor(
            values=jt.values().pin_memory(),
            weights=_pin_optional(jt._weights),
            lengths=_pin_optional(jt._lengths),
            offsets=_pin_optional(jt._offsets),
        )
        for name, jt in self.sequence_dense_features.items()
    }

    return Batch(
        dense_features=pinned_dense,
        sparse_features=_pin_all(self.sparse_features),
        sequence_mulval_lengths=_pin_all(self.sequence_mulval_lengths),
        sequence_dense_features=pinned_seq_dense,
        labels=_pin_all(self.labels),
        reserves=self.reserves,
        tile_size=self.tile_size,
        sample_weights=_pin_all(self.sample_weights),
    )