in tfx_bsl/tfxio/tf_sequence_example_record.py [0:0]
def _ProjectTfmdSchema(self, tensor_names: List[Text]) -> schema_pb2.Schema:
  """Projects self._schema by the given tensor names.

  Args:
    tensor_names: Names of the tensors to keep. Every name must be a key of
      the mapping returned by `self.TensorRepresentations()`.

  Returns:
    A new `schema_pb2.Schema` that contains only the features which are source
    columns of the requested tensors, plus the TensorRepresentations of those
    tensors. For the feature named `_SEQUENCE_COLUMN_NAME`, only the child
    features that are used are kept, and the struct feature is omitted
    entirely when none of its children is used.

  Raises:
    ValueError: If some requested names are not present in the original
      TensorRepresentations, or if the feature named `_SEQUENCE_COLUMN_NAME`
      is not of type STRUCT.
  """
  tensor_representations = self.TensorRepresentations()
  tensor_names = set(tensor_names)
  if not tensor_names.issubset(tensor_representations):
    # `tensor_representations` is a mapping, and `set - mapping` raises
    # TypeError; difference() accepts any iterable (here: the mapping's keys),
    # so the intended ValueError is what actually reaches the caller.
    raise ValueError(
        "Unable to project {} because they were not in the original "
        "TensorRepresentations.".format(
            tensor_names.difference(tensor_representations)))

  # Collect every source column path referenced by the requested tensors.
  used_paths = set()
  for tensor_name in tensor_names:
    used_paths.update(
        tensor_representation_util.GetSourceColumnsFromTensorRepresentation(
            tensor_representations[tensor_name]))

  result = schema_pb2.Schema()
  # Note: We only copy projected features into the new schema because the
  # coder, and ArrowSchema() only care about Schema.feature. If they start
  # depending on other Schema fields then those fields must also be projected.
  for f in self._schema.feature:
    p = path.ColumnPath(f.name)
    if f.name == _SEQUENCE_COLUMN_NAME:
      if f.type != schema_pb2.STRUCT:
        raise ValueError(
            "Feature {} was expected to be of type STRUCT, but got {}"
            .format(f.name, f))
      # Copy the struct feature itself, then rebuild its struct_domain so that
      # it only holds the child features that are actually used.
      result_sequence_struct = schema_pb2.Feature()
      result_sequence_struct.CopyFrom(f)
      result_sequence_struct.ClearField("struct_domain")
      any_sequence_feature_projected = False
      for sf in f.struct_domain.feature:
        sequence_feature_path = p.child(sf.name)
        if sequence_feature_path in used_paths:
          any_sequence_feature_projected = True
          result_sequence_struct.struct_domain.feature.add().CopyFrom(sf)
      # Skip the struct entirely when none of its children was projected.
      if any_sequence_feature_projected:
        result.feature.add().CopyFrom(result_sequence_struct)
    elif p in used_paths:
      result.feature.add().CopyFrom(f)

  # Keep only the TensorRepresentations of the requested tensors.
  tensor_representation_util.SetTensorRepresentationsInSchema(
      result,
      {k: v for k, v in tensor_representations.items() if k in tensor_names})
  return result