in tzrec/datasets/utils.py [0:0]
def to(self, device: torch.device, non_blocking: bool = False) -> "Batch":
    """Return a new batch whose per-feature containers live on ``device``.

    Every value held in ``dense_features``, ``sparse_features``,
    ``sequence_mulval_lengths``, ``sequence_dense_features``, ``labels`` and
    ``sample_weights`` is transferred with
    ``.to(device=device, non_blocking=non_blocking)`` into a freshly built
    dict; the dicts on ``self`` are not mutated. ``reserves`` and
    ``tile_size`` are handed to the new batch as the very same objects,
    with no device transfer.

    Args:
        device: target device, forwarded to each value's ``.to`` call.
        non_blocking: forwarded unchanged to each value's ``.to`` call.

    Returns:
        A new ``Batch`` assembled from the transferred containers.
    """

    def _moved(mapping):
        # Fresh dict with every value sent to the target device; keys kept.
        return {
            name: value.to(device=device, non_blocking=non_blocking)
            for name, value in mapping.items()
        }

    # Keyword arguments are evaluated left to right, so the per-field
    # transfers are issued in the order listed here.
    return Batch(
        dense_features=_moved(self.dense_features),
        sparse_features=_moved(self.sparse_features),
        sequence_mulval_lengths=_moved(self.sequence_mulval_lengths),
        sequence_dense_features=_moved(self.sequence_dense_features),
        labels=_moved(self.labels),
        reserves=self.reserves,
        tile_size=self.tile_size,
        sample_weights=_moved(self.sample_weights),
    )