in tfx_bsl/cc/arrow/array_util.cc [691:804]
// Converts a (possibly nested) ListArray or LargeListArray into the
// coordinate (COO) representation of a sparse tensor.
//
// A k-nested list array becomes a sparse tensor of rank k + 1: the first
// dimension indexes the rows of `list_array`, and each following dimension
// indexes positions within the sub-lists of the next nesting level. Only the
// number of elements of the innermost non-list array matters here; its
// element values are not read.
//
// Args:
//   list_array: the array to convert. If it is not a LIST / LARGE_LIST array
//     it is treated as 0-nested, i.e. its i-th element gets coordinate [i].
//   coo_array: on success, set to an Int64Array holding the coordinates of
//     every leaf value, flattened: the (k + 1) coordinates of the i-th leaf
//     value occupy positions [i * (k + 1), (i + 1) * (k + 1)).
//   dense_shape_array: on success, set to an Int64Array of k + 1 entries, the
//     maximum sub-list length at each level (entry 0 comes from the dummy
//     outermost level spanning all rows of `list_array`).
//
// Returns a non-OK status if allocating the coordinate buffer or building the
// dense shape fails.
absl::Status CooFromListArray(
    const std::shared_ptr<arrow::Array>& list_array,
    std::shared_ptr<arrow::Array>* coo_array,
    std::shared_ptr<arrow::Array>* dense_shape_array) {
  // A ListArray encodes its list structure using offsets, or "row splits",
  // where [row_splits[i], row_splits[i+1]) are the indices of values of
  // the i-th sub-list. For example:
  // [[a, b], [], [c]] is encoded as:
  //   value_array: [a, b, c]
  //   row_splits: [0, 2, 2, 3]
  // A k-nested ListArray is encoded recursively as a row_splits array
  // and a (k-1)-nested ListArray (or a non-list array if k == 1). A 1-nested
  // ListArray is a ListArray<primitive>.
  std::vector<RowSplitsRep> nested_row_splits;
  // The outermost level is a dummy single list containing every row of
  // `list_array`, so that the row index itself becomes coordinate 0.
  // NOTE(review): the row count is narrowed to int32_t here; inputs with more
  // than INT32_MAX rows would be truncated.
  std::array<int32_t, 2> dummy_outermost_row_splits = {
      0, static_cast<int32_t>(list_array->length())};
  nested_row_splits.push_back(
      RowSplitsRep(absl::MakeSpan(dummy_outermost_row_splits)));

  // Records the row splits of one list level and returns that level's child
  // values, restricted to the range its row splits cover. Instantiated for
  // both ListArray and LargeListArray.
  auto strip_one_level = [&nested_row_splits](auto& typed_list) {
    RowSplitsRep row_splits(typed_list);
    nested_row_splits.push_back(row_splits);
    // The values array is not sliced even if the list array is, so slice it
    // here.
    return typed_list.values()->Slice(
        row_splits.front(), row_splits.back() - row_splits.front());
  };

  // Strip `list_array` level by level, populating `nested_row_splits`, until
  // `values` is the innermost non-list array.
  std::shared_ptr<arrow::Array> values = list_array;
  bool is_list_array = true;
  while (is_list_array) {
    switch (values->type()->id()) {
      case arrow::Type::LIST:
        values = strip_one_level(*static_cast<ListArray*>(values.get()));
        break;
      case arrow::Type::LARGE_LIST:
        values = strip_one_level(*static_cast<LargeListArray*>(values.get()));
        break;
      default:
        is_list_array = false;
        break;
    }
  }

  // Allocate a buffer for the coordinates. A k-nested ListArray is converted
  // to a sparse tensor of k + 1 dimensions (one per entry of
  // `nested_row_splits`). The buffer holds all coordinates concatenated, so
  // it needs room for (k + 1) * num_values numbers.
  const int64_t num_values = values->length();
  const int64_t coo_length = static_cast<int64_t>(nested_row_splits.size());
  const int64_t coo_buffer_size =
      coo_length * num_values * static_cast<int64_t>(sizeof(int64_t));
  std::shared_ptr<arrow::Buffer> coo_buffer;
  TFX_BSL_ASSIGN_OR_RETURN_ARROW(
      coo_buffer,
      arrow::AllocateBuffer(coo_buffer_size, arrow::default_memory_pool()));
  int64_t* coo_flat = reinterpret_cast<int64_t*>(coo_buffer->mutable_data());

  // The COO of `values`[i] is [x, ..., y, z] if `values`[i] is the z-th
  // element of its owning sub-list, which is the y-th element of its owning
  // sub-list, ..., which is the x-th element of the outermost, dummy
  // "ListArray" described by `dummy_outermost_row_splits`.
  //
  // Given i, the index of an element in the values array of a ListArray, its
  // owning sub-list is the j-th one if row_splits[j] <= i < row_splits[j + 1],
  // and i - row_splits[j] is that element's position in the sub-list, i.e.
  // its coordinate at that level.
  // NOTE(review): `values` at each level was sliced to start at
  // row_splits.front(), so `current_idx` is relative to that start; this
  // relies on RowSplitsRep::LookupAndUpdate() handling a non-zero first row
  // split (sliced inputs) — confirm.
  //
  // `current_owning_sublist_indices[level]` holds, for the leaf value being
  // processed, the index of the sub-list owning it at `level`. These indices
  // are non-decreasing as the values are visited in order, which is why the
  // vector persists across iterations and is updated in place by
  // LookupAndUpdate().
  std::vector<size_t> current_owning_sublist_indices(nested_row_splits.size(),
                                                     0);
  for (int64_t i = 0; i < num_values; ++i) {
    int64_t* current_coo = coo_flat + i * coo_length;
    int64_t current_idx = i;
    // Walk from the innermost level outwards: at each level, the owning
    // sub-list's index becomes the element index at the next outer level.
    for (int64_t level = coo_length - 1; level >= 0; --level) {
      const int64_t row_split_begin = nested_row_splits[level].LookupAndUpdate(
          current_idx, &current_owning_sublist_indices[level]);
      current_coo[level] = current_idx - row_split_begin;
      current_idx = current_owning_sublist_indices[level];
    }
  }

  // The dense shape is the bounding box of the ListArray: the maximum length
  // of the sub-lists at each level.
  arrow::Int64Builder dense_shape_builder;
  TFX_BSL_RETURN_IF_ERROR(
      FromArrowStatus(dense_shape_builder.Reserve(coo_length)));
  for (const auto& row_splits : nested_row_splits) {
    dense_shape_builder.UnsafeAppend(row_splits.MaxLength());
  }
  TFX_BSL_RETURN_IF_ERROR(
      FromArrowStatus(dense_shape_builder.Finish(dense_shape_array)));
  *coo_array =
      std::make_shared<arrow::Int64Array>(coo_length * num_values, coo_buffer);
  return absl::OkStatus();
}