in sagemaker-neo-notebooks/edge/cpp-integration/tutorial.cc [344:425]
/**
 * Creates a SageMaker Neo compilation job for a model stored in S3 and blocks
 * until the job reaches a terminal state.
 *
 * The model is read from s3://<bucket_name>/<model_name> as an MXNet model whose
 * input "data" has shape [1,3,224,224]; compiled artifacts are written to
 * s3://<bucket_name>/output. The IAM role named by ROLE_NAME is used as the
 * job's execution role.
 *
 * @param bucket_name S3 bucket holding the input model and receiving the output.
 * @param model_name  Object key of the model inside bucket_name.
 * @param target      Target device name, resolved via TargetDeviceMapper.
 * @throws std::runtime_error if the IAM role is not found, CreateCompilationJob
 *         fails, the job ends in FAILED or STOPPED, or it does not complete
 *         within the job's MaxRuntimeInSeconds (plus one poll interval).
 */
void CompileNeoModel(const std::string& bucket_name, const std::string& model_name,
                     const std::string& target) {
  // Job-side runtime limit and client-side poll cadence. The number of poll
  // attempts is derived from the job's own limit so the client does not give
  // up while SageMaker is still allowed to be compiling.
  constexpr int kMaxRuntimeInSec = 900;
  constexpr std::chrono::seconds kPollInterval(30);
  constexpr int kMaxPollAttempts =
      kMaxRuntimeInSec / static_cast<int>(kPollInterval.count()) + 1;

  // set input parameters
  const std::string input_s3 = "s3://" + bucket_name + "/" + model_name;
  const Aws::SageMaker::Model::Framework framework = Aws::SageMaker::Model::Framework::MXNET;
  const Aws::String aws_s3_uri(input_s3.c_str(), input_s3.size());
  const Aws::String data_shape = "{'data':[1,3,224,224]}";
  const std::string job_name = getJobName();
  const Aws::String aws_job_name(job_name.c_str(), job_name.size());

  // set input config
  Aws::SageMaker::SageMakerClient sm_client = getSageMakerClient();
  Aws::SageMaker::Model::InputConfig input_config;
  input_config.SetS3Uri(aws_s3_uri);
  input_config.SetFramework(framework);
  input_config.SetDataInputConfig(data_shape);

  // set output config parameters
  const std::string output_s3 = "s3://" + bucket_name + "/output";
  const Aws::String aws_output_s3(output_s3.c_str(), output_s3.size());

  // set target device
  const Aws::String aws_target_device(target.c_str(), target.size());
  const Aws::SageMaker::Model::TargetDevice target_device =
      Aws::SageMaker::Model::TargetDeviceMapper::GetTargetDeviceForName(aws_target_device);

  // set output config
  Aws::SageMaker::Model::OutputConfig output_config;
  output_config.SetS3OutputLocation(aws_output_s3);
  output_config.SetTargetDevice(target_device);

  // set stopping condition
  Aws::SageMaker::Model::StoppingCondition stopping_condition;
  stopping_condition.SetMaxRuntimeInSeconds(kMaxRuntimeInSec);

  // get iam role; an empty ARN is treated as "role not found"
  Aws::IAM::Model::Role role;
  getIamRole(ROLE_NAME, role);
  if (role.GetArn().empty()) {
    throw std::runtime_error("Role doesn't exist");
  }

  // create Neo compilation job
  Aws::SageMaker::Model::CreateCompilationJobRequest create_job_request;
  create_job_request.SetCompilationJobName(aws_job_name);
  create_job_request.SetRoleArn(role.GetArn());
  create_job_request.SetInputConfig(input_config);
  create_job_request.SetOutputConfig(output_config);
  create_job_request.SetStoppingCondition(stopping_condition);
  auto compilation_resp = sm_client.CreateCompilationJob(create_job_request);
  if (!compilation_resp.IsSuccess()) {
    auto error = compilation_resp.GetError();
    std::string error_str = getErrorMessage(error);
    throw std::runtime_error("CreateCompilationJob error: " + error_str);
  }

  // poll job until it reaches a terminal state or the polling budget runs out
  bool is_success = false;
  for (int attempt = 0; attempt < kMaxPollAttempts; ++attempt) {
    std::cout << "Waiting for compilation..." << std::endl;
    const auto status = poll_job_status(job_name);
    if (status == Aws::SageMaker::Model::CompilationJobStatus::COMPLETED) {
      std::cout << "Compile successfully" << std::endl;
      is_success = true;
      break;
    }
    if (status == Aws::SageMaker::Model::CompilationJobStatus::FAILED) {
      throw std::runtime_error("Compilation fail");
    }
    // A stopped job will never complete; report it instead of spinning until
    // the timeout and misreporting it as one.
    if (status == Aws::SageMaker::Model::CompilationJobStatus::STOPPED) {
      throw std::runtime_error("Compilation stopped");
    }
    // No point sleeping after the last check.
    if (attempt + 1 < kMaxPollAttempts) {
      std::this_thread::sleep_for(kPollInterval);
    }
  }
  if (!is_success) {
    throw std::runtime_error("Compilation timeout");
  }
  std::cout << "Done!" << std::endl;
}