def _ProjectTfmdSchema()

in tfx_bsl/tfxio/tf_sequence_example_record.py [0:0]


  def _ProjectTfmdSchema(self, tensor_names: List[Text]) -> schema_pb2.Schema:
    """Projects self._schema by the given tensor names.

    A top-level feature of self._schema is kept only if its path is one of
    the source columns of the requested tensors' TensorRepresentations. The
    feature named _SEQUENCE_COLUMN_NAME is handled specially: it must be a
    STRUCT, and only the children in its struct_domain whose paths are used
    by a requested tensor are kept. If no child is used, the whole sequence
    feature is dropped. The TensorRepresentations of the requested tensors
    are then written into the returned schema.

    Args:
      tensor_names: names of the tensors to project onto. Each must be a key
        of self.TensorRepresentations(). Duplicates are ignored.

    Returns:
      A newly built schema_pb2.Schema. self._schema itself is only read.

    Raises:
      ValueError: if any name in `tensor_names` is not a key of
        self.TensorRepresentations(), or if the feature named
        _SEQUENCE_COLUMN_NAME in self._schema is not of type STRUCT.
    """
    tensor_representations = self.TensorRepresentations()
    tensor_names = set(tensor_names)
    # Note: `set - dict` raises TypeError, so use set.difference (which
    # accepts any iterable, here the dict's keys) to keep the intended
    # ValueError reachable.
    missing_names = tensor_names.difference(tensor_representations)
    if missing_names:
      raise ValueError(
          "Unable to project {} because they were not in the original "
          "TensorRepresentations.".format(sorted(missing_names)))

    # Union of all source column paths needed by the requested tensors.
    used_paths = set()
    for tensor_name in tensor_names:
      used_paths.update(
          tensor_representation_util.GetSourceColumnsFromTensorRepresentation(
              tensor_representations[tensor_name]))

    result = schema_pb2.Schema()
    # Note: We only copy projected features into the new schema because the
    # coder, and ArrowSchema() only care about Schema.feature. If they start
    # depending on other Schema fields then those fields must also be projected.
    for f in self._schema.feature:
      p = path.ColumnPath(f.name)
      if f.name == _SEQUENCE_COLUMN_NAME:
        if f.type != schema_pb2.STRUCT:
          raise ValueError(
              "Feature {} was expected to be of type STRUCT, but got {}"
              .format(f.name, f))
        # Copy the sequence feature without its children, then add back only
        # the children that some requested tensor reads.
        result_sequence_struct = schema_pb2.Feature()
        result_sequence_struct.CopyFrom(f)
        result_sequence_struct.ClearField("struct_domain")
        any_sequence_feature_projected = False
        for sf in f.struct_domain.feature:
          if p.child(sf.name) in used_paths:
            any_sequence_feature_projected = True
            result_sequence_struct.struct_domain.feature.add().CopyFrom(sf)
        # Drop the sequence feature entirely if none of its children is used.
        if any_sequence_feature_projected:
          result.feature.add().CopyFrom(result_sequence_struct)
      elif p in used_paths:
        result.feature.add().CopyFrom(f)

    tensor_representation_util.SetTensorRepresentationsInSchema(
        result,
        {k: v for k, v in tensor_representations.items() if k in tensor_names})

    return result