inline at::Tensor unUnbind()

in fairring/utils.h [352:375]


/// Glue tensors that are consecutive, equally-sized chunks of a single storage
/// back into one flat 1-D tensor that aliases that storage (roughly the
/// inverse of unbind). No data is copied.
///
/// @param ts  Non-empty list of tensors. ts[i] must start exactly
///            i * ts[0].numel() elements after ts[0] in the same storage.
///            Each must be strided and dense/non-overlapping, with the same
///            numel, key set and dtype as ts[0].
/// @return    A contiguous 1-D tensor of ts[0].numel() * ts.size() elements.
///            It starts at ts[0].storage_offset() and carries ts[0]'s key set
///            and dtype.
///
/// Any violated precondition trips a MY_CHECK. The elements follow storage
/// order. For dense but non-contiguous inputs (e.g. permuted views) this can
/// differ from their logical element order.
/// NOTE(review): the result is a freshly built TensorImpl. Nothing here links
/// it to the inputs' autograd/version-counter state. Confirm callers only use
/// it outside autograd.
inline at::Tensor unUnbind(const std::vector<at::Tensor>& ts) {
  MY_CHECK(!ts.empty());

  const std::size_t count = ts.size();
  for (std::size_t idx = 0; idx < count; ++idx) {
    const at::Tensor& t = ts[idx];
    // Dense and non-overlapping: t covers exactly numel() storage slots.
    MY_CHECK(t.layout() == at::kStrided);
    MY_CHECK(t.is_non_overlapping_and_dense());
    // Same-sized chunks living in the same buffer...
    MY_CHECK(t.numel() == ts[0].numel());
    MY_CHECK(t.storage() == ts[0].storage());
    // ...laid out back-to-back with no gaps, in list order.
    MY_CHECK(
        t.storage_offset() ==
        ts[0].storage_offset() + static_cast<int64_t>(idx) * ts[0].numel());
    // Dispatch keys and dtype must agree so one TensorImpl can describe all.
    MY_CHECK(t.key_set() == ts[0].key_set());
    MY_CHECK(t.dtype() == ts[0].dtype());
  }

  const at::Tensor& head = ts[0];
  const int64_t totalNumel = head.numel() * static_cast<int64_t>(count);

  // Build a view-type impl over the very same storage, starting where the
  // first chunk starts and spanning all chunks as one contiguous 1-D run.
  c10::Storage sharedStorage(head.storage());
  auto flatImpl = c10::make_intrusive<c10::TensorImpl>(
      c10::TensorImpl::VIEW,
      std::move(sharedStorage),
      head.key_set(),
      head.dtype());
  flatImpl->set_storage_offset(head.storage_offset());
  flatImpl->set_sizes_contiguous({totalNumel});
  return at::Tensor(std::move(flatImpl));
}