in tensorflow_model_analysis/eval_saved_model/load.py [0:0]
def predict_list(self,
                 inputs: MultipleInputFeedType) -> List[FetchedTensorValues]:
  """Like predict, but takes a list of inputs.

  Args:
    inputs: A list of input data (or a dict of keys to lists of input data).
      See predict for more details.

  Returns:
    A list of FetchedTensorValues. See predict for more details.

  Raises:
    ValueError: If the input_refs returned after feeding the inputs is not a
      1-D integer array, is not batch-aligned with the fetched features,
      predictions, labels and additional fetches, or contains an index that
      is out of range for the inputs.
  """
  if isinstance(inputs, dict):
    # Only add values for keys that are in the input map (in order).
    input_args = [inputs[key] for key in self._input_map if key in inputs]
    # input_refs index individual examples, so the valid range is the number
    # of examples in the fed lists -- not the number of keys in the dict.
    num_inputs = min((len(arg) for arg in input_args), default=0)
  else:
    input_args = [inputs]
    num_inputs = len(inputs)

  use_iterator = self._iterator_initializer_fn is not None
  if use_iterator:
    # The inputs go to the iterator initializer instead; _predict_list_fn is
    # then called with no arguments until it raises OutOfRangeError.
    self._iterator_initializer_fn(*input_args)
    input_args = []

  result = []
  while True:
    try:
      (features, predictions, labels, input_refs,
       additional_fetches) = self._predict_list_fn(*input_args)
    except tf.errors.OutOfRangeError:
      # Only the fetch itself is guarded: this is the end-of-data signal.
      break

    if (not isinstance(input_refs, np.ndarray) or input_refs.ndim != 1 or
        not np.issubdtype(input_refs.dtype, np.integer)):
      raise ValueError('input_refs should be an 1-D array of integers. '
                       'input_refs was {}.'.format(input_refs))
    batch_size = input_refs.shape[0]

    # Copy so the mapping returned by _predict_list_fn is not mutated.
    all_fetches = dict(additional_fetches)
    all_fetches[constants.FEATURES_NAME] = features
    all_fetches[constants.LABELS_NAME] = labels
    all_fetches[constants.PREDICTIONS_NAME] = predictions

    # TODO(cyfoo): Optimise this.
    # Split each non-scalar fetched value into per-example slices (scalar
    # values are skipped) and check every split lines up with input_refs.
    split_fetches = {}
    for group, tensors in all_fetches.items():
      split_tensors = {
          key: util.split_tensor_value(value)
          for key, value in tensors.items()
          if not np.isscalar(value)
      }
      for result_key, split_values in split_tensors.items():
        if len(split_values) != batch_size:
          raise ValueError(
              'input_refs should be batch-aligned with fetched values; '
              '{} key {} had {} slices but input_refs had batch size of '
              '{}'.format(group, result_key, len(split_values), batch_size))
      split_fetches[group] = split_tensors

    for i, input_ref in enumerate(input_refs):
      if input_ref < 0 or input_ref >= num_inputs:
        raise ValueError(
            'An index in input_refs is out of range: {} vs {}; '
            'inputs: {}'.format(input_ref, num_inputs, inputs))
      values = {}
      for group, split_tensors in split_fetches.items():
        tensor_values = {
            key: split_value[i] for key, split_value in split_tensors.items()
        }
        values[group] = util.extract_tensor_maybe_dict(group, tensor_values)
      result.append(FetchedTensorValues(input_ref=input_ref, values=values))

    # Without an iterator, a single call to _predict_list_fn covers all inputs.
    if not use_iterator:
      break

  return result