torch::Tensor ArrowArray2Tensor()

in graphlearn_torch/v6d/vineyard_utils.cc [146:171]


namespace {

// Wraps the value buffer of a fixed-width numeric Arrow array as a
// row-major [length / cols, cols] CPU tensor of `dtype`, without copying.
//
// The tensor aliases the Arrow buffer. Its deleter holds a copy of `array`,
// so the Arrow buffers stay alive for as long as the tensor's storage does,
// rather than depending on the caller to keep the array alive.
//
// Throws std::runtime_error if `array` is not actually an `ArrowArrayT`, or if
// its length is not a multiple of `cols`. `cols` must be > 0 (checked by the
// caller).
//
// NOTE(review): the Arrow validity bitmap is not consulted; slots marked null
// hold unspecified values and are exposed as-is.
template <typename ArrowArrayT>
torch::Tensor WrapNumericArrowArray(const std::shared_ptr<arrow::Array>& array,
                                    int64_t cols, torch::Dtype dtype) {
  auto typed = std::dynamic_pointer_cast<ArrowArrayT>(array);
  if (typed == nullptr) {
    throw std::runtime_error("Arrow array does not match its reported type: " +
                             array->type()->ToString());
  }
  const int64_t length = typed->length();
  if (length % cols != 0) {
    // Previously the trailing remainder was silently dropped.
    throw std::runtime_error("Arrow array length " + std::to_string(length) +
                             " is not divisible by col_num " +
                             std::to_string(cols));
  }
  // raw_values() already applies the array's slice offset. from_blob needs a
  // non-const pointer, so writes through the tensor modify the Arrow buffer.
  void* data =
      const_cast<void*>(static_cast<const void*>(typed->raw_values()));
  auto options = torch::TensorOptions().dtype(dtype).device(torch::kCPU);
  // The no-op deleter exists only to keep `array` (and its buffers) alive
  // until the tensor storage is released.
  return torch::from_blob(data, {length / cols, cols},
                          [keep_alive = array](void*) {}, options);
}

}  // namespace

// Converts a primitive Arrow column (int32/int64/float32/float64) into a 2-D
// CPU tensor of shape [fscol->length() / col_num, col_num] that shares memory
// with the Arrow buffer (zero-copy). The tensor keeps `fscol` alive.
//
// Throws std::runtime_error if `fscol` is null, if `col_num` is 0 or exceeds
// INT64_MAX, if the length is not a multiple of `col_num`, or if the column
// type is unsupported.
torch::Tensor ArrowArray2Tensor(
  std::shared_ptr<arrow::Array> fscol, uint64_t col_num) {
  if (fscol == nullptr) {
    throw std::runtime_error("ArrowArray2Tensor: input array is null");
  }
  // Guard against division by zero and against values that would wrap to a
  // negative tensor dimension.
  const auto cols = static_cast<int64_t>(col_num);
  if (cols <= 0) {
    throw std::runtime_error(
        "ArrowArray2Tensor: col_num must be in [1, INT64_MAX], got " +
        std::to_string(col_num));
  }

  const auto& type = fscol->type();
  if (type->Equals(arrow::int32())) {
    return WrapNumericArrowArray<arrow::Int32Array>(fscol, cols, torch::kI32);
  }
  if (type->Equals(arrow::int64())) {
    return WrapNumericArrowArray<arrow::Int64Array>(fscol, cols, torch::kI64);
  }
  if (type->Equals(arrow::float32())) {  // dtype: float
    return WrapNumericArrowArray<arrow::FloatArray>(fscol, cols, torch::kF32);
  }
  if (type->Equals(arrow::float64())) {  // dtype: double
    return WrapNumericArrowArray<arrow::DoubleArray>(fscol, cols, torch::kF64);
  }
  throw std::runtime_error("Unsupported column type: " + fscol->type()->ToString());
}