absl::Status CooFromListArray()

in tfx_bsl/cc/arrow/array_util.cc [691:804]


// Converts a (possibly nested) ListArray / LargeListArray into the COO
// (coordinate list) representation of a sparse tensor.
//
// Args:
//   list_array: the array to convert. Each LIST / LARGE_LIST level is peeled
//     off until a non-list values array is reached.
//   coo_array: output. Set to an Int64Array holding the coordinates of all the
//     leaf values concatenated: with D = (number of list levels + 1), the
//     coordinates of the i-th leaf value occupy [i * D, (i + 1) * D).
//   dense_shape_array: output. Set to an Int64Array of D elements holding
//     `MaxLength()` of the row splits of each level, outermost first.
//
// Returns an InvalidArgument error if the length of `list_array` cannot be
// represented as a 32-bit row split, and propagates any error coming from the
// Arrow buffer allocation or from building the dense shape array.
absl::Status CooFromListArray(
    const std::shared_ptr<arrow::Array>& list_array,
    std::shared_ptr<arrow::Array>* coo_array,
    std::shared_ptr<arrow::Array>* dense_shape_array) {
  // A ListArray encodes its list structure using offsets, or "row splits"
  // where [row_splits[i], row_splits[i+1]) are the indices of values of
  // the i-th sub-list. For example:
  // [[a, b], [], [c]] is encoded as:
  // value_array: [a, b, c]
  // row_splits: [0, 2, 2, 3]
  // A k-nested ListArray is encoded recursively as a row_splits array
  // and a (k-1)-nested ListArray. A 1-nested ListArray is a
  // ListArray<primitive>, i.e. its values are a primitive array.

  // The outermost level is modeled as a dummy "ListArray" that contains
  // `list_array` as its only sub-list. Its row splits are 32-bit, so a length
  // that does not fit must be rejected instead of being silently truncated
  // (which would produce wrong coordinates and a wrong dense shape).
  if (list_array->length() > INT32_MAX) {
    return absl::InvalidArgumentError(
        "CooFromListArray: the input array has too many elements; its length "
        "does not fit in a 32-bit row split.");
  }
  // Must outlive `nested_row_splits`, whose first element is built from a span
  // over this array.
  std::array<int32_t, 2> dummy_outermost_row_splits = {
      0, static_cast<int32_t>(list_array->length())};
  std::vector<RowSplitsRep> nested_row_splits;
  nested_row_splits.push_back(
      RowSplitsRep(absl::MakeSpan(dummy_outermost_row_splits)));

  // Strip `list_array` and populate `nested_row_splits` with row_splits of
  // each level.
  std::shared_ptr<arrow::Array> values = list_array;
  // Records the row splits of one list level and replaces `values` with the
  // values of that level. Shared by the LIST and LARGE_LIST cases below.
  const auto strip_one_level = [&nested_row_splits,
                                &values](auto& typed_list) {
    RowSplitsRep row_splits(typed_list);
    nested_row_splits.push_back(row_splits);
    // Note that the values array is not sliced even if `typed_list` is, so
    // we slice it here.
    values = typed_list.values()->Slice(
        row_splits.front(), row_splits.back() - row_splits.front());
  };
  bool reached_leaf_values = false;
  while (!reached_leaf_values) {
    switch (values->type()->id()) {
      case arrow::Type::LIST:
        strip_one_level(*static_cast<ListArray*>(values.get()));
        break;
      case arrow::Type::LARGE_LIST:
        strip_one_level(*static_cast<LargeListArray*>(values.get()));
        break;
      default:
        reached_leaf_values = true;
        break;
    }
  }

  // Allocate a buffer for the coordinates. A k-nested ListArray will be
  // converted to a sparse tensor of k+1 dimensions. The buffer for the
  // coordinates will contain all the coordinates concatenated, so it needs to
  // hold (k + 1) * num_values numbers.
  const int64_t coo_length = static_cast<int64_t>(nested_row_splits.size());
  const int64_t num_values = values->length();
  const int64_t num_coordinates = coo_length * num_values;
  std::shared_ptr<arrow::Buffer> coo_buffer;
  TFX_BSL_ASSIGN_OR_RETURN_ARROW(
      coo_buffer,
      arrow::AllocateBuffer(
          num_coordinates * static_cast<int64_t>(sizeof(int64_t)),
          arrow::default_memory_pool()));
  int64_t* coo_flat = reinterpret_cast<int64_t*>(coo_buffer->mutable_data());

  // COO for the `values`[i] is [x, ..., y, z] if `values`[i] is the z-th
  // element in its belonging sub-list, which is the y-th element in its
  // belonging sub-list,... which is the x-th element in its belonging sub-list,
  // which is the only element in the outermost, dummy "ListArray" denoted by
  // `dummy_outermost_row_splits`.
  //
  // Given i, the index of an element in the value array of a ListArray,
  // its belonging sub-list is the j-th one, if
  // row_splits[j] <= i < row_splits[j + 1]. And i - row_splits[j] is that
  // element's position in the sub-list, thus the coordinate.

  // This vector stores the indices of the sub-lists at each level the last
  // leaf value belongs to.
  // Note that the elements in this vector are non-decreasing as we go through
  // the values in order. That's why it persists outside the loop and gets
  // updated by LookupAndUpdate().
  std::vector<size_t> current_owning_sublist_indices(nested_row_splits.size(),
                                                     0);
  for (int64_t i = 0; i < num_values; ++i) {
    int64_t* current_coo = coo_flat + i * coo_length;
    int64_t current_idx = i;
    // The inner loop looks for the index in the belonging sub-list at each
    // level, from the innermost level to the outermost one: the index of the
    // owning sub-list found at one level is the element index to look up at
    // the next outer level.
    for (int64_t level = coo_length - 1; level >= 0; --level) {
      const int64_t row_split_begin =
          nested_row_splits[level].LookupAndUpdate(
              current_idx, &current_owning_sublist_indices[level]);
      current_coo[level] = current_idx - row_split_begin;
      current_idx = current_owning_sublist_indices[level];
    }
  }

  // The dense shape is the bounding box of the ListArray: the maximum lengths
  // of sub-lists in each level.
  arrow::Int64Builder dense_shape_builder;
  TFX_BSL_RETURN_IF_ERROR(
      FromArrowStatus(dense_shape_builder.Reserve(coo_length)));
  for (const auto& row_splits : nested_row_splits) {
    // Safe: capacity for `coo_length` elements was reserved above.
    dense_shape_builder.UnsafeAppend(row_splits.MaxLength());
  }

  TFX_BSL_RETURN_IF_ERROR(
      FromArrowStatus(dense_shape_builder.Finish(dense_shape_array)));

  *coo_array =
      std::make_shared<arrow::Int64Array>(num_coordinates, coo_buffer);
  return absl::OkStatus();
}