pplx::task<json::value> RemoteExecutor::StartTask()

in nodemanager/core/RemoteExecutor.cpp [139:282]


// Handles a request to start one task of a job on this node.
//
// The executor's writer lock is taken on entry and held until the function
// returns. The task is first registered in jobTaskTable and its affinity and
// requeue count are recorded. After that one of three things happens:
//   * the job has no entry in jobUsers: the job is removed from jobTaskTable
//     again and std::runtime_error is thrown;
//   * the command line is empty (logged as an MPI non-master task): no Process
//     is created; if CCP_DOCKER_IMAGE is set, StartMpiContainer.sh is run and
//     its outcome is logged (a failure is logged only, it is not propagated);
//   * otherwise a Process is created, stored in `processes` under the task's
//     ProcessKey and started - unless one is already stored under that key or
//     the task entry was not newly added, in which case only a warning is
//     logged.
//
// args        - job/task ids and start info; several StartInfo members are
//               moved from when a Process is created.
// callbackUri - moved into the exit callback and handed to
//               ReportTaskCompletion from there.
//
// Returns an already-completed task holding a default-constructed json::value.
pplx::task<json::value> RemoteExecutor::StartTask(StartTaskArgs&& args, std::string&& callbackUri)
{
    WriterLock writerLock(&this->lock);

    bool isNewEntry;
    std::shared_ptr<TaskInfo> taskInfo = this->jobTaskTable.AddJobAndTask(args.JobId, args.TaskId, isNewEntry);

    taskInfo->Affinity = args.StartInfo.Affinity;
    taskInfo->SetTaskRequeueCount(args.StartInfo.TaskRequeueCount);

    // The job must already be known to this node; otherwise undo the
    // registration made above and report the problem to the caller.
    auto jobUserEntry = this->jobUsers.find(args.JobId);
    if (jobUserEntry == this->jobUsers.end())
    {
        this->jobTaskTable.RemoveJob(args.JobId);
        throw std::runtime_error(String::Join(" ", "Job", args.JobId, "was not started on this node."));
    }

    // First element of the jobUsers entry is used as the account name for the
    // process / container script.
    std::string userName = std::get<0>(jobUserEntry->second);

    if (args.StartInfo.CommandLine.empty())
    {
        Logger::Info(args.JobId, args.TaskId, args.StartInfo.TaskRequeueCount, "MPI non-master task found, skip creating the process.");

        // NOTE(review): operator[] is used for the lookups; assuming this is a
        // std::map-like container, a missing key yields an empty string.
        auto& environment = args.StartInfo.EnvironmentVariables;
        std::string image = environment["CCP_DOCKER_IMAGE"];
        std::string nvidiaFlag = environment["CCP_DOCKER_NVIDIA"];
        std::string startOption = environment["CCP_DOCKER_START_OPTION"];
        std::string skipSsh = environment["CCP_DOCKER_SKIP_SSH_SETUP"];

        if (image.empty())
        {
            // No container requested: nothing else to do for this task.
            return pplx::task_from_result(json::value());
        }

        taskInfo->IsPrimaryTask = false;

        // Judging from the " "-separated Join call above, the first argument
        // of String::Join appears to be the separator, so joining two quote
        // characters with the value wraps the value in double quotes -
        // TODO confirm against String::Join.
        // NOTE(review): only the start option has embedded quotes escaped;
        // the other values are passed through as received.
        image = String::Join(image, "\"", "\"");
        nvidiaFlag = String::Join(nvidiaFlag, "\"", "\"");
        boost::replace_all(startOption, "\"", "\\\"");
        startOption = String::Join(startOption, "\"", "\"");
        skipSsh = String::Join(skipSsh, "\"", "\"");

        std::string scriptOutput;
        int exitCode = System::ExecuteCommandOut(scriptOutput, "/bin/bash 2>&1", "StartMpiContainer.sh", taskInfo->TaskId, userName, image, nvidiaFlag, startOption, skipSsh);
        if (exitCode != 0)
        {
            Logger::Error(taskInfo->JobId, taskInfo->TaskId, taskInfo->GetTaskRequeueCount(), "Start MPI container failed with exitcode {0}. {1}", exitCode, scriptOutput);
        }
        else
        {
            Logger::Info(taskInfo->JobId, taskInfo->TaskId, taskInfo->GetTaskRequeueCount(), "Start MPI container successfully.");
        }

        return pplx::task_from_result(json::value());
    }

    // A process is only created for a newly added task entry that has no
    // process stored under its key yet.
    const bool processExists = this->processes.find(taskInfo->ProcessKey) != this->processes.end();
    if (processExists || !isNewEntry)
    {
        Logger::Warn(taskInfo->JobId, taskInfo->TaskId, taskInfo->GetTaskRequeueCount(),
            "The task has started already.");
        return pplx::task_from_result(json::value());
    }

    // Callback handed to Process (presumably invoked when the process ends -
    // confirm in Process). It records the result on the task, reports
    // completion to the callback URI, removes the task entry and finally
    // drops the Process from `processes`.
    auto onProcessExit = [taskInfo, uri = std::move(callbackUri), this] (
        int exitCode,
        std::string&& message,
        const ProcessStatistics& stat)
    {
        try
        {
            json::value completionBody;

            taskInfo->CancelGracefulThread();

            {
                WriterLock stateLock(&this->lock);

                if (!taskInfo->Exited)
                {
                    taskInfo->Exited = true;
                    taskInfo->ExitCode = exitCode;
                    taskInfo->Message = std::move(message);
                    taskInfo->AssignFromStat(stat);

                    completionBody = taskInfo->ToCompletionEventArgJson();
                }
                else
                {
                    // Exit was already recorded; completionBody stays
                    // default-constructed in this case.
                    Logger::Debug(taskInfo->JobId, taskInfo->TaskId, taskInfo->GetTaskRequeueCount(),
                        "Ended already by EndTask.");
                }
            }

            // Reported after the lock above has been released.
            this->ReportTaskCompletion(taskInfo->JobId, taskInfo->TaskId,
                taskInfo->GetTaskRequeueCount(), completionBody, uri);

            // this won't remove the task entry added later as attempt id doesn't match
            this->jobTaskTable.RemoveTask(taskInfo->JobId, taskInfo->TaskId, taskInfo->GetAttemptId());
        }
        catch (const std::exception& ex)
        {
            Logger::Error(taskInfo->JobId, taskInfo->TaskId, taskInfo->GetTaskRequeueCount(),
                "Exception when sending back task result. {0}", ex.what());
        }

        Logger::Debug(taskInfo->JobId, taskInfo->TaskId, taskInfo->GetTaskRequeueCount(),
            "attemptId {0}, processKey {1}, erasing process", taskInfo->GetAttemptId(), taskInfo->ProcessKey);

        {
            WriterLock eraseLock(&this->lock);

            // Process will be deleted here.
            this->processes.erase(taskInfo->ProcessKey);
        }
    };

    auto process = std::shared_ptr<Process>(new Process(
        taskInfo->JobId,
        taskInfo->TaskId,
        taskInfo->GetTaskRequeueCount(),
        "Task",
        std::move(args.StartInfo.CommandLine),
        std::move(args.StartInfo.StdOutFile),
        std::move(args.StartInfo.StdErrFile),
        std::move(args.StartInfo.StdInFile),
        std::move(args.StartInfo.WorkDirectory),
        userName,
        true,
        std::move(args.StartInfo.Affinity),
        std::move(args.StartInfo.EnvironmentVariables),
        std::move(onProcessExit)));

    // Store the process before starting it, keyed by the task's ProcessKey.
    this->processes[taskInfo->ProcessKey] = process;
    Logger::Debug(
        args.JobId, args.TaskId, taskInfo->GetTaskRequeueCount(),
        "StartTask for ProcessKey {0}, process count {1}", taskInfo->ProcessKey, this->processes.size());

    // The start result is only used for logging; this function does not wait
    // for the returned task.
    process->Start(process).then([this, taskInfo] (std::pair<pid_t, pthread_t> startedIds)
    {
        if (startedIds.first <= 0)
        {
            return;
        }

        Logger::Debug(taskInfo->JobId, taskInfo->TaskId, taskInfo->GetTaskRequeueCount(),
            "Process started pid {0}, tid {1}", startedIds.first, startedIds.second);
    });

    return pplx::task_from_result(json::value());
}