Intrusive_ptr<Example> Recordio_protobuf_reader::decode()

in src/mlio/recordio_protobuf_reader.cc [265:341]


Intrusive_ptr<Example> Recordio_protobuf_reader::decode(const Instance_batch &batch) const
{
    Decoder_state state{*this, batch.size()};

    std::size_t num_instances = batch.instances().size();

    constexpr std::size_t cut_off = 10'000'000;

    bool should_run_serial =
        // If we have any sparse features, we cannot decode the example
        // in parallel as we need to append each instance sequentially
        // to the COO tensor.
        has_sparse_feature_ ||
        // If the bad-example handling mode is pad, we cannot parallelize
        // decoding as good records must be stacked together without any
        // gap in between.
        params().bad_example_handling == Bad_example_handling::pad ||
        params().bad_example_handling == Bad_example_handling::pad_warn ||
        // If the number of values (e.g. integers, floating-point numbers)
        // we need to decode is below the cut-off threshold, avoid parallel
        // execution; otherwise the threading overhead would likely outweigh
        // any speed-up.
        num_values_per_instance_ * num_instances < cut_off;

    std::optional<std::size_t> num_instances_read{};
    if (should_run_serial) {
        num_instances_read = decode_serial(state, batch);
    }
    else {
        num_instances_read = decode_parallel(state, batch);
    }

    // Check if we failed to decode the example and return a null pointer
    // if that is the case.
    if (num_instances_read == std::nullopt) {
        if (params().bad_example_handling == Bad_example_handling::skip_warn) {
            logger::warn("The example #{0:n} has been skipped as it had at least one bad instance.",
                         batch.index());
        }

        return nullptr;
    }

    if (num_instances != *num_instances_read) {
        if (params().bad_example_handling == Bad_example_handling::pad_warn) {
            logger::warn("The example #{0:n} has been padded as it had {1:n} bad instance(s).",
                         batch.index(),
                         num_instances - *num_instances_read);
        }
    }

    auto tsr_beg = state.tensors.begin();
    auto tsr_end = state.tensors.end();

    auto bld_beg = state.coo_tensor_builders.begin();
    auto bld_end = state.coo_tensor_builders.end();

    auto ftr_beg = tbb::make_zip_iterator(tsr_beg, bld_beg);
    auto ftr_end = tbb::make_zip_iterator(tsr_end, bld_end);

    for (auto ftr_pos = ftr_beg; ftr_pos < ftr_end; ++ftr_pos) {
        Intrusive_ptr<Tensor> &tensor = std::get<0>(*ftr_pos);

        // If no tensor exists at the specified index, it means the
        // corresponding feature is sparse and we should build its
        // COO tensor.
        if (tensor == nullptr) {
            tensor = std::get<1>(*ftr_pos)->build();
        }
    }

    auto example = make_intrusive<Example>(schema(), std::move(state.tensors));

    example->padding = batch.size() - *num_instances_read;

    return example;
}
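
The serial-vs-parallel decision above boils down to a single predicate. Below is a minimal standalone sketch of that heuristic; the function name and parameters are hypothetical, and only the shape of the check (including the 10'000'000 cut-off) mirrors decode().

#include <cstddef>

// Decode in parallel only when the total number of values to decode is large
// enough to amortize the threading overhead.
bool should_decode_in_parallel(std::size_t num_values_per_instance,
                               std::size_t num_instances,
                               bool has_sparse_feature,
                               bool pads_bad_examples)
{
    constexpr std::size_t cut_off = 10'000'000;

    // Sparse features and pad-style bad-example handling both require
    // instances to be appended strictly in order, so they force serial
    // decoding regardless of the workload size.
    if (has_sparse_feature || pads_bad_examples) {
        return false;
    }

    return num_values_per_instance * num_instances >= cut_off;
}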
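The tensor/builder pairing walks two sequences in lockstep with oneTBB's zip iterator. The snippet below is a self-contained sketch of that pattern with plain vectors standing in for state.tensors and state.coo_tensor_builders; it assumes tbb/iterators.h from oneTBB is available.

#include <cstdio>
#include <tuple>
#include <vector>

#include <tbb/iterators.h>

int main()
{
    std::vector<int> tensors{1, 0, 3};     // 0 plays the role of a null tensor
    std::vector<int> builders{10, 20, 30}; // stand-in for the COO builders

    auto beg = tbb::make_zip_iterator(tensors.begin(), builders.begin());
    auto end = tbb::make_zip_iterator(tensors.end(), builders.end());

    for (auto pos = beg; pos < end; ++pos) {
        int &tensor = std::get<0>(*pos);

        // Mirror the decode() loop: a missing entry is filled in from the
        // paired sequence, just as a null tensor is replaced by the tensor
        // built from its COO tensor builder.
        if (tensor == 0) {
            tensor = std::get<1>(*pos);
        }
    }

    for (int tensor : tensors) {
        std::printf("%d\n", tensor); // prints 1, 20, 3
    }
}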