in torchrec/sparse/jagged_tensor_ops.cpp [180:206]
/// Copies the tensor-valued fields of this KeyedJaggedTensor to `device` and
/// returns all fields as a TupleOptionalFields tuple.
///
/// Tensor fields (values, and weights/lengths/offsets when present) are
/// transferred with `Tensor::to(device, non_blocking)`; absent optional fields
/// stay absent. The non-tensor fields (keys_, stride_, length_per_key_,
/// offset_per_key_, index_per_key_) are copied into the result unchanged.
/// This object itself is not modified.
///
/// @param device        Target device for the tensor fields.
/// @param non_blocking  Forwarded to every `Tensor::to` call.
/// @return Tuple of (keys, values, weights, lengths, offsets, stride,
///         length_per_key, offset_per_key, index_per_key).
TupleOptionalFields KeyedJaggedTensor::to(
    torch::Device device,
    bool non_blocking) {
  // Transfers an optional tensor to `device`, preserving "no value".
  const auto to_device =
      [&](const auto& maybe_tensor) -> c10::optional<torch::Tensor> {
    if (!maybe_tensor.has_value()) {
      return c10::nullopt;
    }
    return maybe_tensor->to(device, non_blocking);
  };

  // Separate statements keep the transfer order identical to the field order.
  auto values = values_.to(device, non_blocking);
  auto weights = to_device(weights_);
  auto lengths = to_device(lengths_);
  auto offsets = to_device(offsets_);

  // The locals are not used after this point. Moving them into the tuple
  // avoids a redundant refcount increment/decrement per tensor. Members must
  // be copied because this object keeps its own state.
  return std::make_tuple(
      keys_,
      std::move(values),
      std::move(weights),
      std::move(lengths),
      std::move(offsets),
      stride_,
      length_per_key_,
      offset_per_key_,
      index_per_key_);
}