in miscellaneous/distributed_tensorflow_mask_rcnn/container-serving-optimized/resources/predict.py [0:0]
def get_predictor(cls):
    """Load the trained Mask R-CNN checkpoint and return a cached OfflinePredictor.

    The predictor is built on the first call and stored on ``cls.predictor``;
    later calls return that cached object. ``cls.lock`` is held while the
    cache is checked and the predictor is built.

    Environment variables read:
        SM_MODEL_DIR: directory searched for ``model-*.index`` checkpoints
            (default ``/opt/ml/model``).
        RESNET_ARCH: ``"resnet101"`` selects a ResNet-101 block layout; any
            other value (default ``"resnet50"``) selects ResNet-50.
        CONFIG__<A>__<B>...: sets ``cfg.A.B...`` to ``eval(value)``.

    Side effects:
        Sets ``TENSORPACK_FP16=true`` in ``os.environ``, mutates ``cfg`` and
        calls ``finalize_configs(is_training=False)``; prints the chosen
        checkpoint and every applied config override.

    Returns:
        The cached predictor.

    Raises:
        FileNotFoundError: if no ``model-*.index`` checkpoint exists in the
            model directory.
    """
    with cls.lock:
        # Another caller may already have built the predictor.
        if cls.predictor:
            return cls.predictor

        os.environ["TENSORPACK_FP16"] = "true"

        model_dir = os.environ.get("SM_MODEL_DIR", "/opt/ml/model")
        resnet_arch = os.environ.get("RESNET_ARCH", "resnet50")

        def _checkpoint_sort_key(path):
            # Checkpoints are named "model-<step>.index". Compare the step
            # numerically: a plain string comparison would rank
            # "model-9999.index" above "model-10000.index". Names without an
            # integer step rank below all numbered checkpoints.
            step = os.path.basename(path)[len("model-") : -len(".index")]
            try:
                return (1, int(step), path)
            except ValueError:
                return (0, 0, path)

        # File path of the most recent previously trained Mask R-CNN checkpoint.
        checkpoints = glob.glob(os.path.join(model_dir, "model-*.index"))
        trained_model = max(checkpoints, key=_checkpoint_sort_key, default=None)
        if trained_model is None:
            # Fail fast with a clear message instead of handing an empty
            # path to the model loader further down.
            raise FileNotFoundError(
                f"No trained model checkpoint matching 'model-*.index' found in {model_dir}"
            )
        print(f"Using model: {trained_model}")

        # Create the Mask R-CNN model (FP16 enabled via TENSORPACK_FP16 above).
        # NOTE(review): the model is constructed before the CONFIG__ overrides
        # below are applied; confirm ResNetFPNModel does not read cfg values
        # in its constructor that the overrides are expected to change.
        mask_rcnn_model = ResNetFPNModel(True)

        # FPN + mask head, with the ResNet block layout chosen by RESNET_ARCH.
        cfg.MODE_FPN = True
        cfg.MODE_MASK = True
        if resnet_arch == "resnet101":
            cfg.BACKBONE.RESNET_NUM_BLOCKS = [3, 4, 23, 3]
        else:
            cfg.BACKBONE.RESNET_NUM_BLOCKS = [3, 4, 6, 3]

        # Apply overrides: CONFIG__A__B=<expr> sets cfg.A.B to eval(<expr>).
        cfg_prefix = "CONFIG__"
        for key, raw_value in dict(os.environ).items():
            if not key.startswith(cfg_prefix):
                continue
            attr_name = key[len(cfg_prefix) :].replace("__", ".")
            # SECURITY: eval() executes arbitrary code taken from the
            # environment. This assumes the variables are set by whoever
            # deploys the container and never derive from request data --
            # confirm before exposing them to any external input.
            value = eval(raw_value)
            print(f"update config: {attr_name}={value}")
            *parents, leaf = attr_name.split(".")
            target = cfg
            for attr in parents:
                target = getattr(target, attr)
            setattr(target, leaf, value)

        # Calling DetectionDataset gets the number of COCO categories and
        # saves it in the configuration.
        DetectionDataset()
        finalize_configs(is_training=False)

        proposal_scope = "generate_{}_proposals_topk_per_image".format(
            "fpn" if cfg.MODE_FPN else "rpn"
        )
        # Build the inference predictor: PredictConfig ties together the
        # model, the checkpoint loader, and the input/output tensor names.
        cls.predictor = OfflinePredictor(
            PredictConfig(
                model=mask_rcnn_model,
                session_init=get_model_loader(trained_model),
                input_names=["images", "orig_image_dims"],
                output_names=[
                    f"{proposal_scope}/boxes",
                    f"{proposal_scope}/scores",
                    "fastrcnn_all_scores",
                    "output/boxes",
                    "output/scores",
                    "output/labels",
                    "output/masks",
                ],
            )
        )
        return cls.predictor