in source/export/gluoncv_model_export.py [0:0]
def forward_test(self, model_name, image_full_path, iter_times):
    """
    Run an exported detection model on one image, print the average time
    per iteration and display the detections of the last iteration.

    Every timed iteration covers the resize of the input, the transfer to
    the model's device, the forward pass and the mapping of the predicted
    boxes back to the original image scale.

    :param model_name: prefix of the exported model files
        (``<model_name>-symbol.json`` and ``<model_name>-0000.params``);
        also used as the key into ``self._short_size_mapping``
    :param image_full_path: path of test image
    :param iter_times: number of timed iterations, must be at least 1
    :return: None
    :raises ValueError: if ``iter_times`` is smaller than 1
    """
    # Without at least one iteration there are no detections to plot and the
    # average below would divide by zero, so reject that up front.
    if iter_times < 1:
        raise ValueError(
            'iter_times must be at least 1, got {}'.format(iter_times))

    detector = gluon.nn.SymbolBlock.imports(
        symbol_file='{}-symbol.json'.format(model_name),
        input_names=['data'],
        param_file='{}-0000.params'.format(model_name),
        ctx=self._ctx
    )

    # The input must live on the same device as the loaded parameters. Use
    # the context the model was imported with instead of guessing from the
    # mere presence of a GPU on the machine.
    target_ctx = self._ctx
    if isinstance(target_ctx, (list, tuple)):
        target_ctx = target_ctx[0]

    image = mx.image.imread(image_full_path)
    height, width = image.shape[:2]

    # Both values are constant across iterations.
    short_size = self._short_size_mapping[model_name]
    # Ratio between the resized short side and the original short side.
    # NOTE(review): assumes self.resize scales the shorter side to
    # short_size while keeping the aspect ratio - confirm against resize().
    scale_ratio = float(short_size) / min(height, width)

    bbox_coords, bbox_scores, class_ids = [], [], []
    t_start = time.time()
    for _ in range(iter_times):
        normalized_image = self.resize(image=image, short_size=short_size)
        if target_ctx is not None:
            normalized_image = normalized_image.as_in_context(target_ctx)

        # inference
        mx_ids, mx_scores, mx_bounding_boxes = detector(normalized_image)

        # post-process (asnumpy() also waits for the computation to finish,
        # which keeps the measured time meaningful)
        mx_ids = mx_ids.asnumpy()
        mx_scores = mx_scores.asnumpy()
        mx_bounding_boxes = mx_bounding_boxes.asnumpy()

        # resize detection results back to original image size; only the
        # first item of the batch is read
        bbox_coords, bbox_scores, class_ids = [], [], []
        for index, bbox in enumerate(mx_bounding_boxes[0]):
            prob = float(mx_scores[0][index][0])
            # negative scores are skipped
            # NOTE(review): presumably padding entries - confirm
            if prob < 0.0:
                continue
            x_min, y_min, x_max, y_max = bbox
            bbox_coords.append([
                int(x_min / scale_ratio),
                int(y_min / scale_ratio),
                int(x_max / scale_ratio),
                int(y_max / scale_ratio),
            ])
            bbox_scores.append([prob])
            class_ids.append([mx_ids[0][index][0]])
    t_end = time.time()

    bbox_coords = np.array(bbox_coords)
    bbox_scores = np.array(bbox_scores)
    class_ids = np.array(class_ids)
    print('Mode = {}: Inference average time cost = {} seconds'.format(model_name, (t_end - t_start)/iter_times))

    # OpenCV loads BGR; reverse the channel axis to get RGB for plotting.
    rgb_image = cv2.imread(image_full_path, cv2.IMREAD_COLOR)
    rgb_image = rgb_image[:, :, ::-1]
    utils.viz.plot_bbox(rgb_image, bbox_coords, bbox_scores, class_ids, class_names=['Person'], thresh=0.5)
    plt.axis('off')
    plt.show()