in pkg/jobmgr/tasksvc/handler.go [484:704]
func (m *serviceHandler) Start(
ctx context.Context,
body *task.StartRequest) (resp *task.StartResponse, err error) {
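	// Log the outcome of the request when the handler returns: a warning
	// entry on a handler error or an in-response error, an info entry on
	// success.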
defer func() {
headers := yarpcutil.GetHeaders(ctx)
if err != nil || resp.GetError() != nil {
entry := log.WithField("request", body).
WithField("headers", headers)
if err != nil {
entry = entry.WithError(err)
}
if resp.GetError() != nil {
entry = entry.WithField("start_error", resp.GetError().String())
}
entry.Warn("TaskManager.Start failed")
return
}
log.WithField("request", body).
WithField("headers", headers).
Info("TaskManager.Start succeeded")
}()
m.metrics.TaskAPIStart.Inc(1)
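	// Bound all downstream cache and DB calls by the RPC timeout.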
ctx, cancelFunc := context.WithTimeout(
ctx,
_rpcTimeout,
)
defer cancelFunc()
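	// Only the leader job manager instance serves write APIs such as Start.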
if !m.candidate.IsLeader() {
m.metrics.TaskStartFail.Inc(1)
		return nil, yarpcerrors.UnavailableErrorf("Task Start API not supported on non-leader")
}
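	// Get the job from the in-memory cache (adding it if absent) and read
	// its config; a missing config is reported back as job-not-found.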
cachedJob := m.jobFactory.AddJob(body.JobId)
cachedConfig, err := cachedJob.GetConfig(ctx)
if err != nil {
log.WithField("job_id", body.JobId.Value).
WithError(err).
Error("Failed to get job config")
m.metrics.TaskStartFail.Inc(1)
return &task.StartResponse{
Error: &task.StartResponse_Error{
NotFound: &pb_errors.JobNotFound{
Id: body.JobId,
Message: err.Error(),
},
},
}, nil
}
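	// Move the job runtime back to PENDING with the default goal state for
	// its job type, retrying on concurrent runtime writes.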
count := 0
for {
jobRuntime, err := cachedJob.GetRuntime(ctx)
if err != nil {
log.WithField("job_id", body.JobId.Value).
WithError(err).
Info("failed to fetch job runtime while starting tasks")
m.metrics.TaskStartFail.Inc(1)
return nil, err
}
// batch jobs in terminated state cannot be restarted
if cachedConfig.GetType() == pb_job.JobType_BATCH {
if util.IsPelotonJobStateTerminal(jobRuntime.GetState()) {
log.WithFields(log.Fields{
"job_id": body.JobId.Value,
"state": jobRuntime.GetState().String(),
}).Info("cannot start tasks in a terminal job")
m.metrics.TaskStartFail.Inc(1)
return nil, yarpcerrors.InvalidArgumentErrorf(
"cannot start tasks in a terminated job")
}
}
jobRuntime.State = pb_job.JobState_PENDING
jobRuntime.GoalState = goalstateutil.GetDefaultJobGoalState(
cachedConfig.GetType())
// update the job runtime
_, err = cachedJob.CompareAndSetRuntime(ctx, jobRuntime)
if err == jobmgrcommon.UnexpectedVersionError {
			// version conflict on the job runtime; retry up to
			// MaxConcurrencyErrorRetry times before giving up
			count++
if count < jobmgrcommon.MaxConcurrencyErrorRetry {
continue
}
}
if err != nil {
log.WithField("job", body.JobId).
WithError(err).
Error("failed to set job runtime in db")
m.metrics.TaskStartFail.Inc(1)
return &task.StartResponse{
Error: &task.StartResponse_Error{
Failure: &task.TaskStartFailure{
						Message: fmt.Sprintf("task start failed while updating job runtime: %v", err),
},
},
}, nil
}
// job runtime is successfully updated, move on
break
}
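	// Fetch the task infos for the requested instance ranges from the DB.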
taskInfos, err := m.getTaskInfosByRangesFromDB(
ctx, body.GetJobId(), body.GetRanges())
if err != nil {
log.WithField("job", body.JobId).
WithError(err).
Error("failed to get tasks for job in db")
m.metrics.TaskStartFail.Inc(1)
return &task.StartResponse{
Error: &task.StartResponse_Error{
OutOfRange: &task.InstanceIdOutOfRange{
JobId: body.JobId,
InstanceCount: cachedConfig.GetInstanceCount(),
},
},
}, nil
}
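	// Restart each task individually, tracking which instances were started
	// and which failed.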
var startedInstanceIds []uint32
var failedInstanceIds []uint32
for _, taskInfo := range taskInfos {
cachedTask, err := cachedJob.AddTask(ctx, taskInfo.GetInstanceId())
if err != nil {
log.WithFields(log.Fields{
"job_id": body.GetJobId().GetValue(),
"instance_id": taskInfo.InstanceId,
}).Info("failed to add task during task start")
failedInstanceIds = append(failedInstanceIds, taskInfo.InstanceId)
continue
}
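		// Retry the per-task runtime update on version conflicts.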
count := 0
for {
taskRuntime, err := cachedTask.GetRuntime(ctx)
if err != nil {
log.WithFields(log.Fields{
"job_id": body.GetJobId().GetValue(),
"instance_id": taskInfo.InstanceId,
}).Info("failed to fetch runtime during task start")
failedInstanceIds = append(failedInstanceIds, taskInfo.InstanceId)
break
}
if taskRuntime.GetGoalState() != task.TaskState_KILLED {
// ignore start request for tasks with non-killed goal state
log.WithFields(log.Fields{
"instance_id": taskInfo.InstanceId,
"job_id": body.GetJobId().GetValue(),
"goal_state": taskRuntime.GetGoalState().String(),
}).Debug("task was not stopped")
break
}
			// Regenerate the mesos task runtime for a new run and set the
			// default goal state for the job type
healthState := taskutil.GetInitialHealthState(taskInfo.GetConfig())
taskutil.RegenerateMesosTaskRuntime(
body.GetJobId(),
taskInfo.InstanceId,
taskRuntime,
healthState,
)
taskRuntime.GoalState =
jobmgr_task.GetDefaultTaskGoalState(cachedConfig.GetType())
taskRuntime.Message = "Task start API request"
			// Call the task-level compare-and-set directly instead of a
			// job-level API (such as PatchTasks) that writes all tasks in
			// one transaction. Compare-and-set writes cannot be batched: a
			// version conflict on a single task runtime should not fail
			// and retry the entire batch.
_, err = cachedJob.CompareAndSetTask(
ctx,
taskInfo.GetInstanceId(),
taskRuntime,
false,
)
if err == jobmgrcommon.UnexpectedVersionError {
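				// version conflict on the task runtime; retry up to
				// MaxConcurrencyErrorRetry times before giving up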
				count++
if count < jobmgrcommon.MaxConcurrencyErrorRetry {
continue
}
}
if err != nil {
log.WithError(err).
WithFields(log.Fields{
"job_id": body.GetJobId().GetValue(),
"instance_id": taskInfo.InstanceId,
}).Info("failed to write runtime during task start")
failedInstanceIds = append(failedInstanceIds, taskInfo.InstanceId)
} else {
startedInstanceIds = append(startedInstanceIds, taskInfo.InstanceId)
}
break
}
}
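	// Enqueue the restarted tasks and then the job to the goal state engine
	// so the new desired state gets acted upon.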
for _, instID := range startedInstanceIds {
m.goalStateDriver.EnqueueTask(body.GetJobId(), instID, time.Now())
}
goalstate.EnqueueJobWithDefaultDelay(
body.GetJobId(), m.goalStateDriver, cachedJob)
m.metrics.TaskStart.Inc(1)
return &task.StartResponse{
StartedInstanceIds: startedInstanceIds,
InvalidInstanceIds: failedInstanceIds,
}, nil
}