inline bool TransposeEncoder::WriteBuffers()

in riegeli/chunk_encoding/transpose_encoder.cc [482:610]
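
Overview (as read from the code below): within each kind of data the buffers are sorted by size, grouped into buckets of roughly `bucket_size_` bytes, and each bucket is compressed and written to `data_writer`. The header then receives, as varints, the bucket and buffer counts followed by each bucket's compressed size and each buffer's uncompressed size; `buffer_pos` records the index assigned to each field's buffer.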


inline bool TransposeEncoder::WriteBuffers(
    Writer& header_writer, Writer& data_writer,
    absl::flat_hash_map<NodeId, uint32_t>* buffer_pos) {
  size_t num_buffers = 0;
  for (std::vector<BufferWithMetadata>& buffers : data_) {
    // Sort buffers by length, smallest to largest; ties are broken by node
    // id so that the resulting order is deterministic.
    std::sort(
        buffers.begin(), buffers.end(),
        [](const BufferWithMetadata& a, const BufferWithMetadata& b) {
          if (a.buffer->size() != b.buffer->size()) {
            return a.buffer->size() < b.buffer->size();
          }
          if (a.node_id.parent_message_id != b.node_id.parent_message_id) {
            return a.node_id.parent_message_id < b.node_id.parent_message_id;
          }
          return a.node_id.tag < b.node_id.tag;
        });
    num_buffers += buffers.size();
  }
  const Chain& nonproto_lengths = nonproto_lengths_writer_.dest();
  if (!nonproto_lengths.empty()) ++num_buffers;

  std::vector<size_t> compressed_bucket_sizes;
  std::vector<size_t> buffer_sizes;
  buffer_sizes.reserve(num_buffers);

  chunk_encoding_internal::Compressor bucket_compressor(compressor_options_);
  for (const std::vector<BufferWithMetadata>& buffers : data_) {
    // Split the buffers of this kind into buckets of roughly `bucket_size_`
    // bytes. First pass: walk buffers from largest to smallest and record
    // the planned uncompressed size of each bucket (a standalone sketch of
    // this heuristic follows the function).
    size_t remaining_buffers_size = 0;
    for (const BufferWithMetadata& buffer : buffers) {
      remaining_buffers_size += buffer.buffer->size();
    }

    std::vector<size_t> uncompressed_bucket_sizes;
    size_t current_bucket_size = 0;
    for (std::vector<BufferWithMetadata>::const_reverse_iterator iter =
             buffers.crbegin();
         iter != buffers.crend(); ++iter) {
      const size_t current_buffer_size = iter->buffer->size();
      // Close the current bucket once adding even half of the next buffer
      // would reach the target size.
      if (current_bucket_size > 0 &&
          current_bucket_size + current_buffer_size / 2 >= bucket_size_) {
        uncompressed_bucket_sizes.push_back(current_bucket_size);
        current_bucket_size = 0;
      }
      current_bucket_size += current_buffer_size;
      remaining_buffers_size -= current_buffer_size;
      // If the remaining (smaller) buffers together fit in half a bucket,
      // merge them all into the current bucket instead of opening a new one.
      if (remaining_buffers_size <= bucket_size_ / 2) {
        current_bucket_size += remaining_buffers_size;
        break;
      }
    }
    if (current_bucket_size > 0) {
      uncompressed_bucket_sizes.push_back(current_bucket_size);
    }

    // Second pass: walk buffers from smallest to largest, consuming the
    // planned bucket sizes from the back (they were recorded in the opposite
    // order), and pass each buffer to `AddBuffer()`.
    current_bucket_size = 0;
    for (const BufferWithMetadata& buffer : buffers) {
      absl::optional<size_t> new_uncompressed_bucket_size;
      // Set iff this buffer is the first one of a new bucket.
      if (current_bucket_size == 0) {
        RIEGELI_ASSERT(!uncompressed_bucket_sizes.empty())
            << "Bucket sizes and buffer sizes do not match";
        current_bucket_size = uncompressed_bucket_sizes.back();
        uncompressed_bucket_sizes.pop_back();
        new_uncompressed_bucket_size = current_bucket_size;
      }
      RIEGELI_ASSERT_GE(current_bucket_size, buffer.buffer->size())
          << "Bucket sizes and buffer sizes do not match";
      current_bucket_size -= buffer.buffer->size();
      // `AddBuffer()` appends the buffer to the bucket being compressed and
      // records the sizes; a set `new_uncompressed_bucket_size` tells it
      // that a new bucket begins here.
      if (ABSL_PREDICT_FALSE(!AddBuffer(
              new_uncompressed_bucket_size, *buffer.buffer, bucket_compressor,
              data_writer, compressed_bucket_sizes, buffer_sizes))) {
        return false;
      }
      // Record the index assigned to this field's buffer.
      const std::pair<absl::flat_hash_map<NodeId, uint32_t>::iterator, bool>
          insert_result = buffer_pos->emplace(
              buffer.node_id, IntCast<uint32_t>(buffer_pos->size()));
      RIEGELI_ASSERT(insert_result.second)
          << "Field already has buffer assigned: "
          << static_cast<uint32_t>(buffer.node_id.parent_message_id) << "/"
          << buffer.node_id.tag;
    }
    RIEGELI_ASSERT(uncompressed_bucket_sizes.empty())
        << "Bucket sizes and buffer sizes do not match";
    RIEGELI_ASSERT_EQ(current_bucket_size, 0u)
        << "Bucket sizes and buffer sizes do not match";
  }
  if (!nonproto_lengths.empty()) {
    // `nonproto_lengths` is the last buffer if non-empty.
    if (ABSL_PREDICT_FALSE(!AddBuffer(nonproto_lengths.size(), nonproto_lengths,
                                      bucket_compressor, data_writer,
                                      compressed_bucket_sizes, buffer_sizes))) {
      return false;
    }
    // Note: `nonproto_lengths` needs no `buffer_pos` entry; it is always the
    // last buffer.
  }

  if (bucket_compressor.writer().pos() > 0) {
    // Flush the last, still-open bucket and record its compressed size.
    const Position pos_before = data_writer.pos();
    if (ABSL_PREDICT_FALSE(!bucket_compressor.EncodeAndClose(data_writer))) {
      return Fail(bucket_compressor.status());
    }
    RIEGELI_ASSERT_GE(data_writer.pos(), pos_before)
        << "Data writer position decreased";
    compressed_bucket_sizes.push_back(
        IntCast<size_t>(data_writer.pos() - pos_before));
  }

  // Write the header: the bucket and buffer counts, then one varint per
  // compressed bucket size, then one varint per buffer size.
  if (ABSL_PREDICT_FALSE(!WriteVarint32(
          IntCast<uint32_t>(compressed_bucket_sizes.size()), header_writer)) ||
      ABSL_PREDICT_FALSE(!WriteVarint32(IntCast<uint32_t>(buffer_sizes.size()),
                                        header_writer))) {
    return Fail(header_writer.status());
  }
  for (const size_t length : compressed_bucket_sizes) {
    if (ABSL_PREDICT_FALSE(
            !WriteVarint64(IntCast<uint64_t>(length), header_writer))) {
      return Fail(header_writer.status());
    }
  }
  for (const size_t length : buffer_sizes) {
    if (ABSL_PREDICT_FALSE(
            !WriteVarint64(IntCast<uint64_t>(length), header_writer))) {
      return Fail(header_writer.status());
    }
  }
  return true;
}
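
For illustration, the bucket-partitioning heuristic above can be isolated into a small standalone sketch (hypothetical names, not part of riegeli). It takes buffer sizes sorted from largest to smallest, mirroring the reverse iteration in the first pass, and returns the planned uncompressed bucket sizes:

#include <cstddef>
#include <numeric>
#include <vector>

std::vector<size_t> PlanBuckets(const std::vector<size_t>& sizes,
                                size_t bucket_size) {
  std::vector<size_t> buckets;
  size_t remaining = std::accumulate(sizes.begin(), sizes.end(), size_t{0});
  size_t current = 0;
  for (const size_t size : sizes) {
    // Close the current bucket once adding even half of the next buffer
    // would reach the target size.
    if (current > 0 && current + size / 2 >= bucket_size) {
      buckets.push_back(current);
      current = 0;
    }
    current += size;
    remaining -= size;
    // Merge a tail of small buffers (together at most half a bucket) into
    // the current bucket instead of opening another one.
    if (remaining <= bucket_size / 2) {
      current += remaining;
      break;
    }
  }
  if (current > 0) buckets.push_back(current);
  return buckets;
}

For example, sizes {100, 90, 10, 5, 5} with a bucket size of 128 yield buckets {100, 110}: the second bucket absorbs the 20-byte tail of small buffers.

The metadata written to `header_writer` above is a flat sequence of varints: the bucket count, the buffer count, one compressed size per bucket, then one size per buffer. A hypothetical standalone parser, assuming the varints are standard LEB128 (`ReadVarint` below is an illustrative helper, not riegeli's API):

#include <cstdint>
#include <optional>
#include <vector>

std::optional<uint64_t> ReadVarint(const uint8_t*& ptr, const uint8_t* end) {
  uint64_t result = 0;
  for (int shift = 0; ptr != end && shift < 64; shift += 7) {
    const uint8_t byte = *ptr++;
    result |= static_cast<uint64_t>(byte & 0x7f) << shift;
    if ((byte & 0x80) == 0) return result;
  }
  return std::nullopt;  // Truncated or overlong varint.
}

bool ParseSizeMetadata(const uint8_t* ptr, const uint8_t* end,
                       std::vector<uint64_t>& compressed_bucket_sizes,
                       std::vector<uint64_t>& buffer_sizes) {
  const std::optional<uint64_t> num_buckets = ReadVarint(ptr, end);
  const std::optional<uint64_t> num_buffers = ReadVarint(ptr, end);
  if (!num_buckets || !num_buffers) return false;
  for (uint64_t i = 0; i < *num_buckets; ++i) {
    const std::optional<uint64_t> size = ReadVarint(ptr, end);
    if (!size) return false;
    compressed_bucket_sizes.push_back(*size);
  }
  for (uint64_t i = 0; i < *num_buffers; ++i) {
    const std::optional<uint64_t> size = ReadVarint(ptr, end);
    if (!size) return false;
    buffer_sizes.push_back(*size);
  }
  return true;
}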