in deepracer_offroad_ws/ctrl_pkg/src/ctrl_state.cpp [135:237]
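/// Runs the five-step model load sequence: stop inference, optimize the
/// model, load it into the inference node, load the action space into the
/// navigation node, and restore the inference state. Returns false if any
/// step fails or if a newer load request supersedes this one.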
bool AutoDriveCtrl::loadModelReq(int requestSeqNum, std::string modelName, std::vector<int> modelMetadataSensors,
int trainingAlgorithm, int actionSpaceType, std::string imgFormat,
int width, int height, int numChannels,
int lidarChannels, int platform, int task, int preProcess) {
(void)task; // 'task' is unused; only RL models are loaded here (see rlTask below).
RCLCPP_INFO(ctrlNode->get_logger(), "New load model request: %d; Latest request: %d", requestSeqNum, latestLoadModelRequestSeq_);
latestLoadModelRequestSeq_ = requestSeqNum;
const std::lock_guard<std::mutex> lock(loadModelMutex_);
try {
modelLoadingStatus_ = "loading";
//! TODO handle ssd models, maybe as a different state and switch
//! to enums when build is consolidated.
int rlTask = 0; // Only load rl models
RCLCPP_INFO(ctrlNode->get_logger(), "[0/5] Starting the load model service calls for req: %d; Latest request: %d", requestSeqNum, latestLoadModelRequestSeq_);
// Stop the inference task
if (!setInferState(rlTask, false)) {
RCLCPP_ERROR(ctrlNode->get_logger(), "Failed to call the inference service");
modelLoadingStatus_ = "error";
return false;
}
RCLCPP_INFO(ctrlNode->get_logger(), "[1/5] Stopped inference task for req: %d; Latest request: %d", requestSeqNum, latestLoadModelRequestSeq_);
if(latestLoadModelRequestSeq_ > requestSeqNum) {
RCLCPP_WARN(ctrlNode->get_logger(), "Newer load model request %d supersedes current request %d; aborting.", latestLoadModelRequestSeq_, requestSeqNum);
// modelLoadingStatus_ stays "loading" here; the superseding request
// updates it once it acquires the lock.
return false;
}
// Optimize the model.
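// The optimizer service returns the path of the inference-ready artifact,
// which is consumed by the model loader below.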
auto moSrvRequest = std::make_shared<deepracer_interfaces_pkg::srv::ModelOptimizeSrv::Request>();
moSrvRequest->model_name = modelName;
moSrvRequest->model_metadata_sensors = modelMetadataSensors;
moSrvRequest->training_algorithm = trainingAlgorithm;
moSrvRequest->img_format = imgFormat;
moSrvRequest->width = width;
moSrvRequest->height = height;
moSrvRequest->num_channels = numChannels;
moSrvRequest->platform = platform;
moSrvRequest->lidar_channels = lidarChannels;
auto future_result_mo = modelOptimizerClient_->async_send_request(moSrvRequest);
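// Blocking wait on the service future; this assumes the response can still
// be delivered while we wait (e.g. a separate executor thread or callback
// group), otherwise this call would deadlock.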
future_result_mo.wait();
auto moSrvResponse = future_result_mo.get();
if(moSrvResponse->error != 0){
RCLCPP_ERROR(ctrlNode->get_logger(), "Model optimizer failed.");
modelLoadingStatus_ = "error";
return false;
}
RCLCPP_INFO(ctrlNode->get_logger(), "[2/5] Optimized the model for req: %d; Latest request: %d", requestSeqNum, latestLoadModelRequestSeq_);
if(latestLoadModelRequestSeq_ > requestSeqNum) {
RCLCPP_WARN(ctrlNode->get_logger(), "Newer load model request %d supersedes current request %d; aborting.", latestLoadModelRequestSeq_, requestSeqNum);
return false;
}
// Load the model into the memory.
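// The same request object is reused below for the action space load, since
// LoadModelSrv carries both the artifact path and the action space type.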
auto modelSrvRequest = std::make_shared<deepracer_interfaces_pkg::srv::LoadModelSrv::Request>();
modelSrvRequest->artifact_path = moSrvResponse->artifact_path;
modelSrvRequest->task_type = rlTask;
modelSrvRequest->pre_process_type = preProcess;
modelSrvRequest->action_space_type = actionSpaceType;
auto future_result_load_model = loadModelClient_->async_send_request(modelSrvRequest);
future_result_load_model.wait();
auto loadModelSrvResponse = future_result_load_model.get();
if(loadModelSrvResponse->error != 0){
RCLCPP_ERROR(ctrlNode->get_logger(), "Model loader failed.");
modelLoadingStatus_ = "error";
return false;
}
RCLCPP_INFO(ctrlNode->get_logger(), "[3/5] Inference node updated for req: %d; Latest request: %d", requestSeqNum, latestLoadModelRequestSeq_);
if(latestLoadModelRequestSeq_ > requestSeqNum) {
RCLCPP_WARN(ctrlNode->get_logger(), "Newer load model request %d supersedes current request %d; aborting.", latestLoadModelRequestSeq_, requestSeqNum);
return false;
}
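// Load the action space in the navigation node, reusing the same request.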
auto future_result_load_action_space = loadActionSpaceClient_->async_send_request(modelSrvRequest);
future_result_load_action_space.wait();
auto loadActionSpaceResponse = future_result_load_action_space.get();
if(loadActionSpaceResponse->error != 0){
RCLCPP_ERROR(ctrlNode->get_logger(), "Action space load failed.");
modelLoadingStatus_ = "error";
return false;
}
RCLCPP_INFO(ctrlNode->get_logger(), "[4/5] Action space in navigation node updated for req: %d; Latest request: %d", requestSeqNum, latestLoadModelRequestSeq_);
if(latestLoadModelRequestSeq_ > requestSeqNum) {
RCLCPP_WARN(ctrlNode->get_logger(), "Newer load model request %d supersedes current request %d; aborting.", latestLoadModelRequestSeq_, requestSeqNum);
return false;
}
// Start inference if the model was running.
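// isActive_ reflects whether autonomous driving was active before the
// load, so inference resumes only if it was previously running.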
if (!setInferState(rlTask, isActive_)) {
RCLCPP_ERROR(ctrlNode->get_logger(), "Failed to reintialize inference");
modelLoadingStatus_ = "error";
return false;
}
RCLCPP_INFO(ctrlNode->get_logger(), "[5/5] Inference task set for req: %d; Latest request: %d", requestSeqNum, latestLoadModelRequestSeq_);
modelLoadingStatus_ = "loaded";
}
catch (const std::exception &ex) {
RCLCPP_ERROR(ctrlNode->get_logger(), "Model failed to load: %s", ex.what());
modelLoadingStatus_ = "error";
return false;
}
return true;
}