in tfx/dsl/placeholder/proto_placeholder.py [0:0]
def _shrink_descriptors(self, fds: descriptor_pb2.FileDescriptorSet) -> None:
  """Deletes all field/message descriptors not used by this placeholder.

  The descriptor set is pruned in place, driven by `self._keep_fields` and
  `self._keep_types`. Both hold fully qualified names without a leading dot
  (e.g. `pkg.Msg.field` and `pkg.Msg`) and are extended while pruning:
    * key/value fields of used map<> entry messages are added to
      `self._keep_fields`,
    * message/enum types referenced by kept fields, and parents of kept
      nested declarations, are added to `self._keep_types`.

  Messages in the `google.protobuf.` package keep all of their fields.

  Args:
    fds: The file descriptor set to shrink. Modified in place.
  """
  # We don't want to shrink any of the "well-known" proto types (like Any),
  # because the proto runtime verifies that the descriptor for these
  # well-known types matches what it expects. The runtimes do this because
  # they then replace the message classes with more specific, native classes,
  # to offer APIs like `Any.Pack()`, for instance.
  well_known_types_pkg = 'google.protobuf.'

  def _full_name(name_prefix: str, name: str) -> str:
    """Joins a scope and a simple name into a qualified name."""
    # A file without a package has an empty prefix. Blindly joining with a
    # dot would yield `.Msg`, which never matches the dot-less names stored
    # in `_keep_types` (see `removeprefix('.')` below), so such types would
    # be deleted even though they are in use.
    return f'{name_prefix}.{name}' if name_prefix else name

  # Step 1: Go over all the message descriptors a first time, including
  #         recursion into nested declarations. Delete field declarations we
  #         don't need. Collect target types we need because they're the value
  #         type of a field we want to keep.
  def _shrink_message(
      name_prefix: str, message_descriptor: descriptor_pb2.DescriptorProto
  ) -> None:
    msg_name = _full_name(name_prefix, message_descriptor.name)
    if not msg_name.startswith(well_known_types_pkg):
      # Mark map<> entry key/value fields as used if the map field is used.
      if (
          message_descriptor.options.map_entry
          and msg_name in self._keep_types
      ):
        self._keep_fields.update({f'{msg_name}.key', f'{msg_name}.value'})

      # Delete unused fields.
      del message_descriptor.extension[:]  # We don't support extension fields
      _remove_unless(
          message_descriptor.field,
          lambda f: f'{msg_name}.{f.name}' in self._keep_fields,
      )

      # Clean up oneofs that have no fields left.
      i = 0
      while i < len(message_descriptor.oneof_decl):
        if all(
            not f.HasField('oneof_index') or f.oneof_index != i
            for f in message_descriptor.field
        ):
          # No references left. Delete this one and shift all indices down.
          # `i` is not advanced, because the next declaration now sits at
          # position `i`.
          del message_descriptor.oneof_decl[i]
          for f in message_descriptor.field:
            if f.oneof_index > i:
              f.oneof_index -= 1
        else:
          i += 1

    # Mark target types of fields as used. This also runs for well-known
    # types, whose fields were not pruned above, hence the explicit
    # `_keep_fields` check.
    for field_descriptor in message_descriptor.field:
      if (
          field_descriptor.type
          in (
              descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE,
              descriptor_pb2.FieldDescriptorProto.TYPE_ENUM,
          )
          and f'{msg_name}.{field_descriptor.name}' in self._keep_fields
      ):
        assert field_descriptor.type_name.startswith('.')
        self._keep_types.add(field_descriptor.type_name.removeprefix('.'))

    # Recurse into nested message types.
    for nested_descriptor in message_descriptor.nested_type:
      _shrink_message(msg_name, nested_descriptor)

  # Outer invocation of step 1 on all files.
  for file_descriptor in fds.file:
    del file_descriptor.service[:]  # We never need RPC services.
    del file_descriptor.extension[:]  # We don't support extension fields.
    for message_descriptor in file_descriptor.message_type:
      _shrink_message(file_descriptor.package, message_descriptor)

  # Step 2: Go over all message descriptors a second time, including recursion
  #         into nested declarations. Delete any nested declarations that were
  #         not marked in the first pass. Mark any messages that have nested
  #         declarations, because runtime descriptor pools require the parent
  #         message to be present (even if unused) before allowing to add
  #         nested message.
  #         (This step is actually called within step 3.)
  def _purge_types(
      name_prefix: str, message_descriptor: descriptor_pb2.DescriptorProto
  ) -> None:
    msg_name = _full_name(name_prefix, message_descriptor.name)
    # Children first, so that a nested message which only survives because of
    # its own nested declarations is already marked when it is filtered below.
    for nested_descriptor in message_descriptor.nested_type:
      _purge_types(msg_name, nested_descriptor)
    _remove_unless(
        message_descriptor.nested_type,
        lambda n: f'{msg_name}.{n.name}' in self._keep_types,
    )
    _remove_unless(
        message_descriptor.enum_type,
        lambda e: f'{msg_name}.{e.name}' in self._keep_types,
    )
    if message_descriptor.nested_type or message_descriptor.enum_type:
      self._keep_types.add(msg_name)

  # Step 3: Remove the unused messages and enums from the file descriptors.
  for file_descriptor in fds.file:
    name_prefix = file_descriptor.package
    for message_descriptor in file_descriptor.message_type:
      _purge_types(name_prefix, message_descriptor)  # Step 2
    # The lambdas below are invoked immediately by `_remove_unless`, so
    # capturing the loop variable `name_prefix` is safe.
    _remove_unless(
        file_descriptor.message_type,
        lambda m: _full_name(name_prefix, m.name) in self._keep_types,  # pylint: disable=cell-var-from-loop
    )
    _remove_unless(
        file_descriptor.enum_type,
        lambda e: _full_name(name_prefix, e.name) in self._keep_types,  # pylint: disable=cell-var-from-loop
    )

  # Step 4: Remove file descriptors that became empty. Remove declared
  # dependencies on other .proto files if those files were removed themselves.
  _remove_unless(fds.file, lambda fd: fd.message_type or fd.enum_type)
  keep_file_names = {fd.name for fd in fds.file}
  for fd in fds.file:
    _remove_unless(fd.dependency, lambda dep: dep in keep_file_names)
    # These hold indices into `dependency`, which may just have changed, so
    # they are dropped rather than remapped.
    del fd.public_dependency[:]
    del fd.weak_dependency[:]