void CompileNeoModel()

in sagemaker-neo-notebooks/edge/cpp-integration/tutorial.cc [344:425]


/// Submits a SageMaker Neo compilation job for an MXNet model stored in S3 and
/// blocks until the job reaches a terminal state.
///
/// The compiled artifact is written to s3://<bucket_name>/output.
///
/// @param bucket_name S3 bucket that holds the model artifact and receives the output.
/// @param model_name  Object key of the model artifact inside @p bucket_name.
/// @param target      Target device name, mapped to a TargetDevice via
///                    TargetDeviceMapper::GetTargetDeviceForName.
/// @throws std::runtime_error if the IAM role named by ROLE_NAME has no ARN,
///         if CreateCompilationJob fails, if the job ends FAILED or STOPPED,
///         or if it has not completed within the polling budget.
void CompileNeoModel(const std::string& bucket_name, const std::string& model_name,
                     const std::string& target) {
  // Server-side runtime limit for the job and the client-side polling interval.
  constexpr int kMaxRuntimeInSec = 900;
  constexpr int kPollIntervalSec = 30;
  // Keep polling for at least as long as the job is allowed to run (one poll at
  // t=0 plus one interval of slack to observe the terminal state), so the client
  // never reports a timeout while the job may still legitimately be running.
  constexpr int kMaxPollAttempts = kMaxRuntimeInSec / kPollIntervalSec + 2;

  // set input parameters
  const std::string input_s3 = "s3://" + bucket_name + "/" + model_name;

  const Aws::SageMaker::Model::Framework framework = Aws::SageMaker::Model::Framework::MXNET;
  const Aws::String aws_s3_uri(input_s3.c_str(), input_s3.size());
  // Input tensor name and shape handed to Neo as the DataInputConfig.
  const Aws::String data_shape = "{'data':[1,3,224,224]}";
  const Aws::String aws_role_name(ROLE_NAME.c_str(), ROLE_NAME.size());

  const std::string job_name = getJobName();
  const Aws::String aws_job_name(job_name.c_str(), job_name.size());

  // set input config
  Aws::SageMaker::SageMakerClient sm_client = getSageMakerClient();
  Aws::SageMaker::Model::InputConfig input_config;
  input_config.SetS3Uri(aws_s3_uri);
  input_config.SetFramework(framework);
  input_config.SetDataInputConfig(data_shape);

  // set output config parameters
  const std::string output_s3 = "s3://" + bucket_name + "/output";
  const Aws::String aws_output_s3(output_s3.c_str(), output_s3.size());

  // set target device
  const Aws::String aws_target_device(target.c_str(), target.size());
  const Aws::SageMaker::Model::TargetDevice target_device =
      Aws::SageMaker::Model::TargetDeviceMapper::GetTargetDeviceForName(aws_target_device);

  // set output config
  Aws::SageMaker::Model::OutputConfig output_config;
  output_config.SetS3OutputLocation(aws_output_s3);
  output_config.SetTargetDevice(target_device);

  // set stopping condition (the service stops the job after this many seconds)
  Aws::SageMaker::Model::StoppingCondition stopping_condition;
  stopping_condition.SetMaxRuntimeInSeconds(kMaxRuntimeInSec);

  // get iam role; an empty ARN is treated as "role not found"
  Aws::IAM::Model::Role role;
  getIamRole(ROLE_NAME, role);
  if (role.GetArn().empty()) {
    throw std::runtime_error("Role doesn't exist");
  }

  // create Neo compilation job
  Aws::SageMaker::Model::CreateCompilationJobRequest create_job_request;
  create_job_request.SetCompilationJobName(aws_job_name);
  create_job_request.SetRoleArn(role.GetArn());
  create_job_request.SetInputConfig(input_config);
  create_job_request.SetOutputConfig(output_config);
  create_job_request.SetStoppingCondition(stopping_condition);

  auto compilation_resp = sm_client.CreateCompilationJob(create_job_request);
  if (!compilation_resp.IsSuccess()) {
    auto error = compilation_resp.GetError();
    std::string error_str = getErrorMessage(error);
    throw std::runtime_error("CreateCompilationJob error: " + error_str);
  }

  // Poll the job until it reaches a terminal state or the budget is exhausted.
  for (int attempt = 1; attempt <= kMaxPollAttempts; ++attempt) {
    std::cout << "Waiting for compilation..." << std::endl;
    const auto status = poll_job_status(job_name);
    if (status == Aws::SageMaker::Model::CompilationJobStatus::COMPLETED) {
      std::cout << "Compile successfully" << std::endl;
      std::cout << "Done!" << std::endl;
      return;
    }
    if (status == Aws::SageMaker::Model::CompilationJobStatus::FAILED) {
      throw std::runtime_error("Compilation fail");
    }
    // STOPPED is terminal too (e.g. manual stop or MaxRuntimeInSeconds reached);
    // waiting longer can never turn it into success.
    if (status == Aws::SageMaker::Model::CompilationJobStatus::STOPPED) {
      throw std::runtime_error("Compilation stopped");
    }
    // Don't sleep after the last poll: the only thing left to do is report timeout.
    if (attempt < kMaxPollAttempts) {
      std::this_thread::sleep_for(std::chrono::seconds(kPollIntervalSec));
    }
  }

  throw std::runtime_error("Compilation timeout");
}