in deepracer_systems_pkg/deepracer_systems_pkg/model_loader_module/model_loader_node.py [0:0]
def __init__(self):
    """Create a ModelLoaderNode.

    Sets up the model bookkeeping state, the services exposed by this node,
    the clients to the status LED and USB monitor services, the USB file
    system notification subscription and the heartbeat timer.

    The constructor blocks until the LED blink, LED solid, USB file system
    subscribe and USB mount point manager services are all available,
    logging a message after every one second wait that times out.
    """
    super().__init__("model_loader_node")
    self.get_logger().info("model_loader_node started")
    self.models_in_progress = {}
    # Threading lock object to safely perform the model operations.
    self.progress_guard = threading.Lock()
    # Scheduler to queue the function calls and run them in a separate thread.
    self.scheduler = scheduler.Scheduler(self.get_logger())
    # Flag to enable deleting all existing models in /opt/aws/deepracer/artifacts folder
    # before copying the models from USB.
    self.enable_model_wipe = model_loader_config.ENABLE_MODEL_WIPE
    # The model optimizer client is created only when the optimizer is enabled in
    # the configuration; otherwise it is left as None.
    if model_loader_config.ENABLE_MODEL_OPTIMIZER:
        self.model_optimizer_client = self.create_client(
            ModelOptimizeSrv,
            model_loader_config.MODEL_OPTIMIZER_SERVER_SERVICE)
    else:
        self.model_optimizer_client = None
    # Supported file extensions and their corresponding action functions.
    self.supported_exts = {
        ".pb": self.copymodel,
        ".json": self.copymodel,
        ".gz": self.unzip,
        ".tar": self.untar
    }
    # Supported model extensions. The trailing comma makes this a real one-element
    # tuple; a bare (".pb") is just the string ".pb", which would turn membership
    # tests into substring matches. str.endswith() accepts either form.
    self.model_file_extensions = (".pb",)
    # Service that is called when a model is loaded to verify if the model
    # was extracted successfully.
    self.verify_model_ready_cb_group = ReentrantCallbackGroup()
    self.verify_model_ready_service = self.create_service(
        VerifyModelReadySrv,
        model_loader_config.VERIFY_MODEL_READY_SERVICE_NAME,
        self.verify_model_ready_cb,
        callback_group=self.verify_model_ready_cb_group)
    # A service that is called to extract a tar.gz file with model uploaded from the console
    # or delete a model selected.
    self.console_model_action_cb_group = ReentrantCallbackGroup()
    self.console_model_action_service = self.create_service(
        ConsoleModelActionSrv,
        model_loader_config.CONSOLE_MODEL_ACTION_SERVICE_NAME,
        self.console_model_action_cb,
        callback_group=self.console_model_action_cb_group)
    # Clients to Status LED services that are called to indicate progress/success/failure
    # status while loading model. Both clients share one mutually exclusive callback group.
    self.led_cb_group = MutuallyExclusiveCallbackGroup()
    self.led_blink_client = self.create_client(
        SetStatusLedBlinkSrv,
        constants.LED_BLINK_SERVICE_NAME,
        callback_group=self.led_cb_group)
    self.led_solid_client = self.create_client(
        SetStatusLedSolidSrv,
        constants.LED_SOLID_SERVICE_NAME,
        callback_group=self.led_cb_group)
    # Block until both LED services are reachable.
    while not self.led_blink_client.wait_for_service(timeout_sec=1.0):
        self.get_logger().info("Led blink service not available, waiting again...")
    while not self.led_solid_client.wait_for_service(timeout_sec=1.0):
        self.get_logger().info("Led solid service not available, waiting again...")
    # Request objects created once and stored on the node.
    self.led_blink_request = SetStatusLedBlinkSrv.Request()
    self.led_solid_request = SetStatusLedSolidSrv.Request()
    # Client to USB File system subscription service that allows the node to add the "models"
    # folder to the watchlist. The usb_monitor_node will trigger notification if it finds
    # the files/folders from the watchlist in the USB drive.
    self.usb_sub_cb_group = ReentrantCallbackGroup()
    self.usb_file_system_subscribe_client = self.create_client(
        USBFileSystemSubscribeSrv,
        constants.USB_FILE_SYSTEM_SUBSCRIBE_SERVICE_NAME,
        callback_group=self.usb_sub_cb_group)
    while not self.usb_file_system_subscribe_client.wait_for_service(timeout_sec=1.0):
        self.get_logger().info("File System Subscribe not available, waiting again...")
    # Client to USB Mount point manager service to indicate that the usb_monitor_node can safely
    # decrement the counter for the mount point once the action function for the file/folder being
    # watched by model_loader_node is successfully executed.
    self.usb_mpm_cb_group = ReentrantCallbackGroup()
    self.usb_mount_point_manager_client = self.create_client(
        USBMountPointManagerSrv,
        constants.USB_MOUNT_POINT_MANAGER_SERVICE_NAME,
        callback_group=self.usb_mpm_cb_group)
    while not self.usb_mount_point_manager_client.wait_for_service(timeout_sec=1.0):
        self.get_logger().info("USB mount point manager service not available, waiting again...")
    # Subscriber to USB File system notification publisher to receive the broadcasted messages
    # with file/folder details, whenever a watched file is identified in the USB connected.
    self.usb_notif_cb_group = ReentrantCallbackGroup()
    self.usb_file_system_notification_sub = self.create_subscription(
        USBFileSystemNotificationMsg,
        constants.USB_FILE_SYSTEM_NOTIFICATION_TOPIC,
        self.usb_file_system_notification_cb,
        10,
        callback_group=self.usb_notif_cb_group)
    # Add the "models" folder to the watchlist. The request is sent asynchronously
    # and the returned future is not awaited here.
    usb_file_system_subscribe_request = USBFileSystemSubscribeSrv.Request()
    usb_file_system_subscribe_request.file_name = model_loader_config.MODEL_SOURCE_LEAF_DIRECTORY
    usb_file_system_subscribe_request.callback_name = model_loader_config.SCHEDULE_MODEL_LOADER_CB
    usb_file_system_subscribe_request.verify_name_exists = True
    self.usb_file_system_subscribe_client.call_async(usb_file_system_subscribe_request)
    # Heartbeat timer: timer_callback is invoked every 5 seconds.
    self.timer_count = 0
    self.timer = self.create_timer(5.0, self.timer_callback)
    self.get_logger().info("Model Loader node successfully created")