def _shrink_descriptors()

in tfx/dsl/placeholder/proto_placeholder.py [0:0]


  def _shrink_descriptors(self, fds: descriptor_pb2.FileDescriptorSet) -> None:
    """Deletes all field/message descriptors not used by this placeholder.

    Prunes `fds` in place. `self._keep_fields` and `self._keep_types` hold the
    fully-qualified names of the fields and types to retain. Both sets can grow
    during the call: map entry key/value fields are added to `_keep_fields`,
    and the value types of kept fields as well as the parents of kept nested
    declarations are added to `_keep_types`.

    Names are compared without a leading dot, i.e. `pkg.Message.field` for
    files that declare a package and `Message.field` for files that don't.

    Args:
      fds: The file descriptor set to shrink. Modified in place.
    """
    # We don't want to shrink any of the "well-known" proto types (like Any),
    # because the proto runtime verifies that the descriptor for these
    # well-known types matches what it expects. The runtimes do this because
    # they then replace the message classes with more specific, native classes,
    # to offer APIs like `Any.Pack()`, for instance.
    well_known_types_pkg = 'google.protobuf.'

    def _qualify(name_prefix: str, name: str) -> str:
      """Joins an enclosing scope and a declaration name into a full name."""
      # Files without a `package` statement have an empty prefix. Their
      # fully-qualified names must not start with a dot, otherwise they never
      # match the names recorded from `type_name.removeprefix('.')` in step 1
      # and every type in such a file would be purged.
      return f'{name_prefix}.{name}' if name_prefix else name

    # Step 1: Go over all the message descriptors a first time, including
    #         recursion into nested declarations. Delete field declarations we
    #         don't need. Collect target types we need because they're the value
    #         type of a field we want to keep.
    def _shrink_message(
        name_prefix: str, message_descriptor: descriptor_pb2.DescriptorProto
    ) -> None:
      msg_name = _qualify(name_prefix, message_descriptor.name)
      if not msg_name.startswith(well_known_types_pkg):
        # Mark map<> entry key/value fields as used if the map field is used.
        # This relies on the parent message having been visited first (it
        # registers the entry type below, before recursing into nested types).
        if (
            message_descriptor.options.map_entry
            and msg_name in self._keep_types
        ):
          self._keep_fields.update({f'{msg_name}.key', f'{msg_name}.value'})

        # Delete unused fields.
        del message_descriptor.extension[:]  # We don't support extension fields
        # NOTE(review): `_remove_unless` is defined elsewhere; presumably it
        # removes, in place, every element for which the predicate is falsy.
        _remove_unless(
            message_descriptor.field,
            lambda f: f'{msg_name}.{f.name}' in self._keep_fields,
        )

        # Clean up oneofs that have no fields left.
        i = 0
        while i < len(message_descriptor.oneof_decl):
          if all(
              not f.HasField('oneof_index') or f.oneof_index != i
              for f in message_descriptor.field
          ):
            # No references left. Delete this one and shift all indices down.
            # `i` is deliberately not advanced: the next declaration has just
            # moved into slot `i`.
            del message_descriptor.oneof_decl[i]
            for f in message_descriptor.field:
              if f.oneof_index > i:
                f.oneof_index -= 1
          else:
            i += 1

      # Mark target types of fields as used. This also runs for well-known
      # types, whose fields were left untouched above.
      for field_descriptor in message_descriptor.field:
        if (
            field_descriptor.type
            in (
                descriptor_pb2.FieldDescriptorProto.TYPE_MESSAGE,
                descriptor_pb2.FieldDescriptorProto.TYPE_ENUM,
            )
            and f'{msg_name}.{field_descriptor.name}' in self._keep_fields
        ):
          # Type references are expected to be fully qualified (leading dot);
          # the dot is stripped to match the naming used by the keep sets.
          assert field_descriptor.type_name.startswith('.')
          self._keep_types.add(field_descriptor.type_name.removeprefix('.'))

      # Recurse into nested message types.
      for nested_descriptor in message_descriptor.nested_type:
        _shrink_message(msg_name, nested_descriptor)

    # Outer invocation of step 1 on all files.
    for file_descriptor in fds.file:
      del file_descriptor.service[:]  # We never need RPC services.
      del file_descriptor.extension[:]  # We don't support extension fields.
      for message_descriptor in file_descriptor.message_type:
        _shrink_message(file_descriptor.package, message_descriptor)

    # Step 2: Go over all message descriptors a second time, including recursion
    #         into nested declarations. Delete any nested declarations that were
    #         not marked in the first pass. Mark any messages that have nested
    #         declarations, because runtime descriptor pools require the parent
    #         message to be present (even if unused) before allowing to add
    #         nested message.
    #         (This step is actually called within step 3.)
    def _purge_types(
        name_prefix: str, message_descriptor: descriptor_pb2.DescriptorProto
    ) -> None:
      msg_name = _qualify(name_prefix, message_descriptor.name)
      # Recurse first, so that children which must be kept as containers of
      # deeper declarations are already marked when they are filtered below.
      for nested_descriptor in message_descriptor.nested_type:
        _purge_types(msg_name, nested_descriptor)
      _remove_unless(
          message_descriptor.nested_type,
          lambda n: f'{msg_name}.{n.name}' in self._keep_types,
      )
      _remove_unless(
          message_descriptor.enum_type,
          lambda e: f'{msg_name}.{e.name}' in self._keep_types,
      )
      if message_descriptor.nested_type or message_descriptor.enum_type:
        self._keep_types.add(msg_name)

    # Step 3: Remove the unused messages and enums from the file descriptors.
    for file_descriptor in fds.file:
      name_prefix = file_descriptor.package
      for message_descriptor in file_descriptor.message_type:
        _purge_types(name_prefix, message_descriptor)  # Step 2
      # The lambdas below capture the loop variable `name_prefix`, which is
      # safe here because they are only used within the same iteration.
      _remove_unless(
          file_descriptor.message_type,
          lambda m: _qualify(name_prefix, m.name) in self._keep_types,  # pylint: disable=cell-var-from-loop
      )
      _remove_unless(
          file_descriptor.enum_type,
          lambda e: _qualify(name_prefix, e.name) in self._keep_types,  # pylint: disable=cell-var-from-loop
      )

    # Step 4: Remove file descriptors that became empty. Remove declared
    # dependencies on other .proto files if those files were removed themselves.
    _remove_unless(fds.file, lambda fd: fd.message_type or fd.enum_type)
    keep_file_names = {fd.name for fd in fds.file}
    for fd in fds.file:
      _remove_unless(fd.dependency, lambda dep: dep in keep_file_names)
      # Public/weak dependency lists are dropped entirely rather than being
      # re-mapped onto the filtered `dependency` list.
      del fd.public_dependency[:]
      del fd.weak_dependency[:]