def predict_list()

in tensorflow_model_analysis/eval_saved_model/load.py [0:0]


  def predict_list(self,
                   inputs: MultipleInputFeedType) -> List[FetchedTensorValues]:
    """Like predict, but takes a list of inputs.

    Args:
      inputs: A list of input data (or a dict of keys to lists of input data).
        See predict for more details.

    Returns:
      A list of FetchedTensorValues. See predict for more details.

    Raises:
      ValueError: If the original input_refs tensor passed to the
        EvalInputReceiver does not align with the features, predictions and
        labels returned after feeding the inputs.
    """
    if isinstance(inputs, dict):
      input_args = []
      # Only add values for keys that are in the input map (in order).
      for key in self._input_map:
        if key in inputs:
          input_args.append(inputs[key])
    else:
      input_args = [inputs]

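    # When an iterator initializer is provided, the inputs are fed to it once
    # up front and _predict_list_fn is then called with no arguments until the
    # iterator is exhausted.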
    if self._iterator_initializer_fn:
      self._iterator_initializer_fn(*input_args)
      input_args = []

    result = []

    while True:
      try:
        (features, predictions, labels, input_refs,
         additional_fetches) = self._predict_list_fn(*input_args)

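        # Fold the standard features, labels and predictions fetches into the
        # additional fetches so that everything is split and regrouped below in
        # a uniform way.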
        all_fetches = additional_fetches
        all_fetches[constants.FEATURES_NAME] = features
        all_fetches[constants.LABELS_NAME] = labels
        all_fetches[constants.PREDICTIONS_NAME] = predictions

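        # Split every batched fetch into per-example values; scalar fetches
        # (which have no batch dimension) are not split and are dropped here.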
        # TODO(cyfoo): Optimise this.
        split_fetches = {}
        for group, tensors in all_fetches.items():
          split_tensors = {}
          for key in tensors:
            if not np.isscalar(tensors[key]):
              split_tensors[key] = util.split_tensor_value(tensors[key])
          split_fetches[group] = split_tensors

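        # input_refs[i] is the index of the raw input that produced output
        # slice i, so it must be a 1-D integer array.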
        if (not isinstance(input_refs, np.ndarray) or input_refs.ndim != 1 or
            not np.issubdtype(input_refs.dtype, np.integer)):
          raise ValueError('input_refs should be a 1-D array of integers. '
                           'input_refs was {}.'.format(input_refs))

        for group, tensors in split_fetches.items():
          for result_key, split_values in tensors.items():
            if len(split_values) != input_refs.shape[0]:
              raise ValueError(
                  'input_refs should be batch-aligned with fetched values; '
                  '{} key {} had {} slices but input_refs had batch size of '
                  '{}'.format(group, result_key, len(split_values),
                              input_refs.shape[0]))

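        # Regroup the per-example slices by their originating input and wrap
        # them in FetchedTensorValues records.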
        for i, input_ref in enumerate(input_refs):
          if input_ref < 0 or input_ref >= len(inputs):
            raise ValueError(
                'An index in input_refs is out of range: {} vs {}; '
                'inputs: {}'.format(input_ref, len(inputs), inputs))
          values = {}
          for group, split_tensors in split_fetches.items():
            tensor_values = {}
            for key, split_value in split_tensors.items():
              tensor_values[key] = split_value[i]
            values[group] = util.extract_tensor_maybe_dict(group, tensor_values)

          result.append(FetchedTensorValues(input_ref=input_ref, values=values))

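        # Without an iterator a single call covers all inputs, so stop after
        # one pass; with an iterator, keep fetching until it raises
        # OutOfRangeError.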
        if self._iterator_initializer_fn is None:
          break
      except tf.errors.OutOfRangeError:
        break

    return result
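
For orientation, here is a minimal standalone sketch (not part of the library) of the regrouping step above: each batched fetch is sliced along its first dimension, and slice i is attributed to the input indexed by input_refs[i]. The function name regroup_by_input_ref and the assumption that every fetch is a dense array with a leading batch dimension are illustrative only; the real code relies on util.split_tensor_value and wraps the results in FetchedTensorValues.

import numpy as np


def regroup_by_input_ref(fetches, input_refs):
  """Groups slice i of every batched array in `fetches` under input_refs[i]."""
  results = []
  for i, input_ref in enumerate(input_refs):
    # Take the i-th slice of every fetched array and attribute it to the
    # input that produced it.
    values = {name: batched[i] for name, batched in fetches.items()}
    results.append((int(input_ref), values))
  return results


# Three output slices, each traced back to the input with the same index.
fetches = {
    'predictions': np.array([[0.9], [0.2], [0.7]]),
    'labels': np.array([1, 0, 1]),
}
input_refs = np.array([0, 1, 2])
for input_ref, values in regroup_by_input_ref(fetches, input_refs):
  print(input_ref, values['predictions'], values['labels'])

As in predict_list itself, nothing requires the mapping to be one-to-one: several output slices may carry the same input_ref, and each still produces its own record.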