def _check_module_is_text_embedding()

in tensorflow_hub/feature_column.py [0:0]


def _check_module_is_text_embedding(module_spec):
  """Raises ValueError if `module_spec` is not a text-embedding module.

  The default signature of a text-embedding module must take exactly one
  input of type Tensor(string, shape=(?,)). It must also produce a "default"
  output of type Tensor(float32, shape=(?,K)), where the batch dimension is
  unknown and the embedding dimension K is statically known and non-zero.

  Args:
    module_spec: A `ModuleSpec` to test.

  Raises:
    ValueError: if `module_spec` default signature is not compatible with
      Tensor(string, shape=(?,)) -> Tensor(float32, shape=(?,K)). The message
      lists every issue found, not only the first one.
  """
  issues = []

  # Find issues with signature inputs.
  input_info_dict = module_spec.get_input_info_dict()
  if len(input_info_dict) != 1:
    issues.append("Module default signature must require only one input")
  else:
    input_info, = input_info_dict.values()
    input_shape = input_info.get_shape()
    # ndims is tested before as_list(): as_list() raises on unknown rank, and
    # the short-circuiting `and` keeps it from being reached in that case.
    if not (input_info.dtype == tf.string and input_shape.ndims == 1 and
            input_shape.as_list() == [None]):
      issues.append("Module default signature must have only one input "
                    "tf.Tensor(shape=(?,), dtype=string)")

  # Find issues with signature outputs.
  output_info_dict = module_spec.get_output_info_dict()
  if "default" not in output_info_dict:
    issues.append("Module default signature must have a 'default' output.")
  else:
    output_info = output_info_dict["default"]
    output_shape = output_info.get_shape()
    is_valid_output = False
    # As above, only call as_list() once the rank is known to be 2.
    if output_info.dtype == tf.float32 and output_shape.ndims == 2:
      batch_dim, embedding_dim = output_shape.as_list()
      # The batch dimension must be unknown (None). A plain falsiness test
      # would also accept a static batch size of 0. The embedding dimension
      # must be statically known and non-zero.
      is_valid_output = batch_dim is None and bool(embedding_dim)
    if not is_valid_output:
      issues.append("Module default signature must have a 'default' output of "
                    "tf.Tensor(shape=(?,K), dtype=float32).")

  if issues:
    raise ValueError("Module is not a text-embedding: %r" % issues)