in container/neo_template_mxnet_byom.py [0:0]
def preprocess(self, batch_data):
    """Parse a batch of requests into per-request model input dicts.

    Each request payload is passed, together with its Content-Type header, to
    the user-supplied ``self.user_module.neo_preprocess``. That function must
    return either a single ``np.ndarray``/``np.generic`` (only accepted when
    the model has exactly one input) or a dict whose keys are exactly the
    names in ``self.input_names``.

    Args:
        batch_data: Sequence of requests, one per batch slot. The payload of
            each request is read with ``.get('data')``, falling back to
            ``.get('body')``.

    Returns:
        A list with one dict per request, mapping model input name to the
        value produced by ``neo_preprocess``.

    Raises:
        AssertionError: If ``len(batch_data)`` differs from ``self._batch_size``.
        Exception: If a request has no content type or no payload, if
            ``neo_preprocess`` raises, or if its result has an unsupported
            type or mismatching input names. Messages for failures caused by
            the user-defined function start with ``'ClientError:'``.
    """
    # Explicit raise instead of ``assert`` so the check is not stripped under
    # ``python -O``; AssertionError is kept so callers see the same type.
    if len(batch_data) != self._batch_size:
        raise AssertionError(
            'Invalid input batch size: expected {} but got {}'.format(
                self._batch_size, len(batch_data)))
    processed_batch_data = []
    for k, req_body in enumerate(batch_data):
        # Try both capitalizations; presumably the header lookup is
        # case-sensitive (NOTE(review): confirm against the context API).
        content_type = self._context.get_request_header(k, 'Content-type')
        if content_type is None:
            content_type = self._context.get_request_header(k, 'Content-Type')
        if content_type is None:
            raise Exception('Content type could not be deduced')
        payload = req_body.get('data')
        if payload is None:
            payload = req_body.get('body')
        if payload is None:
            raise Exception('Nonexistent payload')
        # For BYOM, any content type is allowed
        print('content_type = {}'.format(content_type))
        try:
            # User is responsible for parsing payload into input(s)
            input_values = self.user_module.neo_preprocess(payload, content_type)
        except Exception as e:
            raise Exception('ClientError: User-defined pre-processing function failed:\n'
                            + str(e))
        # Validate parsed input(s) and normalize to a {name: value} dict.
        if isinstance(input_values, (np.ndarray, np.generic)):
            # Single input
            if len(self.input_names) != 1:
                raise Exception('ClientError: User-defined pre-processing function returns ' +
                                'a single input, but the model has multiple inputs.')
            input_values = {self.input_names[0]: input_values}
        elif isinstance(input_values, dict):
            # Multiple inputs: the key set must match the model's input names exactly.
            given_names = set(input_values.keys())
            expected_names = set(self.input_names)
            if given_names != expected_names:  # Input name(s) mismatch
                given_missing = expected_names - given_names
                expected_missing = given_names - expected_names
                # Format the header BEFORE appending the names: the names are
                # user-controlled, and a '{' or '}' in one would otherwise
                # break str.format and mask this ClientError.
                msg = 'ClientError: Input name(s) mismatch: {0} {1}'.format(
                    given_names, expected_names)
                if given_missing:
                    # Sorted so the message is deterministic across runs.
                    msg += ('\nExpected '
                            + ', '.join(sorted(str(s) for s in given_missing))
                            + ' in input data')
                if expected_missing:
                    msg += ('\nThe model does not accept the following inputs: '
                            + ', '.join(sorted(str(s) for s in expected_missing)))
                raise Exception(msg)
        else:
            raise Exception('ClientError: User-defined pre-processing function must return ' +
                            'either dict type or np.ndarray')
        processed_batch_data.append(input_values)
    return processed_batch_data