absl::Status SequenceExamplesToRecordBatchDecoder::Make()

in tfx_bsl/cc/coders/example_decoder.cc [870:961]


// Creates a SequenceExamplesToRecordBatchDecoder and stores it in `*result`.
//
// When `serialized_schema` is absent, the decoder is constructed with a null
// arrow schema, a null struct type and null decoder maps. Otherwise the bytes
// are parsed as a tensorflow::metadata::v0::Schema, and one decoder plus one
// arrow field is created for each distinct feature name:
//   * the top-level feature named `sequence_feature_column_name` must be of
//     type STRUCT; every child of its struct domain is handled as a sequence
//     feature;
//   * every other top-level feature is handled as a context feature.
// If at least one sequence feature was collected, a single struct field named
// `sequence_feature_column_name` wrapping all of them is appended after the
// context fields.
//
// Returns InvalidArgumentError if the schema cannot be parsed or if the
// feature named `sequence_feature_column_name` is not a STRUCT, and forwards
// any error produced by MakeFeatureDecoder, MakeFeatureListDecoder or
// TfmdFeatureToArrowField. `*result` is assigned only on success.
absl::Status SequenceExamplesToRecordBatchDecoder::Make(
    const absl::optional<absl::string_view>& serialized_schema,
    const std::string& sequence_feature_column_name,
    std::unique_ptr<SequenceExamplesToRecordBatchDecoder>* result) {
  using ContextDecoderMap =
      absl::flat_hash_map<std::string, std::unique_ptr<FeatureDecoder>>;
  using SequenceDecoderMap =
      absl::flat_hash_map<std::string, std::unique_ptr<FeatureListDecoder>>;
  using ArrowFields = std::vector<std::shared_ptr<arrow::Field>>;

  // No schema supplied: build a decoder without any schema-derived state.
  if (!serialized_schema.has_value()) {
    *result = absl::WrapUnique(new SequenceExamplesToRecordBatchDecoder(
        sequence_feature_column_name, nullptr, nullptr, nullptr, nullptr));
    return absl::OkStatus();
  }

  tensorflow::metadata::v0::Schema schema;
  if (!schema.ParseFromArray(serialized_schema->data(),
                             serialized_schema->size())) {
    return absl::InvalidArgumentError("Unable to parse schema.");
  }

  auto context_decoders = absl::make_unique<ContextDecoderMap>();
  auto sequence_decoders = absl::make_unique<SequenceDecoderMap>();
  ArrowFields top_level_fields;
  ArrowFields sequence_fields;

  for (const tensorflow::metadata::v0::Feature& feature : schema.feature()) {
    const auto& feature_name = feature.name();

    if (feature_name != sequence_feature_column_name) {
      // Any top-level feature other than the sequence feature column is a
      // context feature.
      //
      // TODO(b/160886325): duplicated features in the (same environment) in
      // the schema should be a bug, but before TFDV checks for it, we tolerate
      // it.
      // TODO(b/160885730): the coder is currently not environment aware, which
      // means if there are two features of the same name but belonging to
      // different environments, the first feature will be taken.
      if (context_decoders->find(feature_name) == context_decoders->end()) {
        TFX_BSL_RETURN_IF_ERROR(
            MakeFeatureDecoder(feature, &(*context_decoders)[feature_name]));
        top_level_fields.emplace_back();
        TFX_BSL_RETURN_IF_ERROR(TfmdFeatureToArrowField(
            /*is_sequence_feature=*/false, feature, &top_level_fields.back()));
      }
      continue;
    }

    // This is the top-level feature that holds the sequence features, as
    // identified by sequence_feature_column_name.
    if (feature.type() != tensorflow::metadata::v0::STRUCT) {
      return absl::InvalidArgumentError(absl::StrCat(
          "Found a feature in the schema with the "
          "sequence_feature_column_name (i.e., ",
          sequence_feature_column_name,
          ") that is not a struct. The sequence_feature_column_name should "
          "be used only for the top-level struct feature with a struct "
          "domain that contains each sequence feature as a child."));
    }

    for (const auto& child : feature.struct_domain().feature()) {
      const auto& child_name = child.name();
      // TODO(b/160886325): duplicated features in the (same environment) in
      // the schema should be a bug, but before TFDV checks for it, we tolerate
      // it.
      // TODO(b/160885730): the coder is currently not environment aware, which
      // means if there are two features of the same name but belonging to
      // different environments, the first feature will be taken.
      if (sequence_decoders->find(child_name) == sequence_decoders->end()) {
        TFX_BSL_RETURN_IF_ERROR(MakeFeatureListDecoder(
            child, &(*sequence_decoders)[child_name]));
        sequence_fields.emplace_back();
        TFX_BSL_RETURN_IF_ERROR(TfmdFeatureToArrowField(
            /*is_sequence_feature=*/true, child, &sequence_fields.back()));
      }
    }
  }

  // Stays null unless the schema declared at least one sequence feature.
  std::shared_ptr<arrow::StructType> sequence_struct_type;
  if (!sequence_fields.empty()) {
    // Wrap all sequence feature fields in one top-level struct field and add
    // it to the arrow schema fields.
    sequence_struct_type =
        std::make_shared<arrow::StructType>(sequence_fields);
    top_level_fields.push_back(
        arrow::field(sequence_feature_column_name, sequence_struct_type));
  }

  *result = absl::WrapUnique(new SequenceExamplesToRecordBatchDecoder(
      sequence_feature_column_name,
      arrow::schema(std::move(top_level_fields)),
      std::move(sequence_struct_type),
      std::move(context_decoders),
      std::move(sequence_decoders)));
  return absl::OkStatus();
}