in pkg/jobmgr/tasksvc/handler.go [780:918]
// Stop handles the TaskManager.Stop API. It sets the goal state of the
// requested task instances of a job to KILLED and enqueues them into the
// goal state driver.
//
// The request is rejected with an Unavailable error when this instance is not
// the leader. All other failures are reported inside the returned
// StopResponse with a nil error:
//   - NotFound when the job config cannot be read,
//   - OutOfRange when the tasks for the requested ranges cannot be loaded,
//   - UpdateError when persisting the KILLED goal state fails; in that case
//     the instances that were not updated are listed in InvalidInstanceIds.
//
// When no ranges are given, or a single range starting at 0 covers every
// instance of the job, the call is delegated to stopJob.
func (m *serviceHandler) Stop(
	ctx context.Context,
	body *task.StopRequest) (resp *task.StopResponse, err error) {

	// Log the outcome on every exit path. The closure reads the named
	// results, so it sees the values that are actually returned.
	defer func() {
		entry := log.WithField("request", body).
			WithField("headers", yarpcutil.GetHeaders(ctx))
		if err == nil && resp.GetError() == nil {
			entry.Info("TaskManager.Stop succeeded")
			return
		}
		if err != nil {
			entry = entry.WithError(err)
		}
		if resp.GetError() != nil {
			entry = entry.WithField("stop_error", resp.GetError().String())
		}
		entry.Warn("TaskManager.Stop failed")
	}()

	m.metrics.TaskAPIStop.Inc(1)
	ctx, cancelFunc := context.WithTimeout(ctx, _rpcTimeout)
	defer cancelFunc()

	if !m.candidate.IsLeader() {
		m.metrics.TaskStopFail.Inc(1)
		return nil, yarpcerrors.UnavailableErrorf(
			"Task Stop API not supported on non-leader")
	}

	jobID := body.GetJobId()
	cachedJob := m.jobFactory.AddJob(jobID)
	cachedConfig, err := cachedJob.GetConfig(ctx)
	if err != nil {
		log.WithField("job_id", jobID).
			WithError(err).
			Error("Failed to get job config")
		m.metrics.TaskStopFail.Inc(1)
		return &task.StopResponse{
			Error: &task.StopResponse_Error{
				NotFound: &pb_errors.JobNotFound{
					Id:      jobID,
					Message: err.Error(),
				},
			},
		}, nil
	}
	instanceCount := cachedConfig.GetInstanceCount()

	taskRange := body.GetRanges()
	stopWholeJob := len(taskRange) == 0 ||
		(len(taskRange) == 1 &&
			taskRange[0].From == 0 &&
			taskRange[0].To >= instanceCount)
	if stopWholeJob {
		// Stop all tasks in a job, stop entire job instead of task by task.
		log.WithField("job_id", jobID.GetValue()).
			Info("stopping all tasks in the job")
		return m.stopJob(ctx, jobID, instanceCount)
	}

	taskInfos, err := m.getTaskInfosByRangesFromDB(ctx, jobID, taskRange)
	if err != nil {
		log.WithField("job", jobID).
			WithError(err).
			Error("failed to get tasks for job in db")
		m.metrics.TaskStopFail.Inc(1)
		return &task.StopResponse{
			Error: &task.StopResponse_Error{
				OutOfRange: &task.InstanceIdOutOfRange{
					JobId:         jobID,
					InstanceCount: instanceCount,
				},
			},
		}, nil
	}

	// Build one runtime diff per instance whose goal state is not yet
	// KILLED; instanceIds records exactly the instances being patched.
	var instanceIds []uint32
	runtimeDiffs := make(map[uint32]jobmgrcommon.RuntimeDiff, len(taskInfos))
	for _, taskInfo := range taskInfos {
		// Skip update task goalstate if it is already KILLED.
		if taskInfo.GetRuntime().GoalState == task.TaskState_KILLED {
			continue
		}
		runtimeDiffs[taskInfo.InstanceId] = jobmgrcommon.RuntimeDiff{
			jobmgrcommon.GoalStateField: task.TaskState_KILLED,
			jobmgrcommon.MessageField:   "Task stop API request",
			jobmgrcommon.ReasonField:    "",
			jobmgrcommon.TerminationStatusField: &task.TerminationStatus{
				Reason: task.TerminationStatus_TERMINATION_STATUS_REASON_KILLED_ON_REQUEST,
			},
		}
		instanceIds = append(instanceIds, taskInfo.InstanceId)
	}

	// Persist the KILLED goal state for the selected tasks.
	_, _, err = cachedJob.PatchTasks(ctx, runtimeDiffs, false)
	if err != nil {
		log.WithError(err).
			WithField("instance_ids", instanceIds).
			WithField("job_id", jobID.GetValue()).
			Error("failed to updated killed goalstate")
		m.metrics.TaskStopFail.Inc(1)
		// No task is enqueued on failure, but the job itself still is.
		goalstate.EnqueueJobWithDefaultDelay(
			jobID, m.goalStateDriver, cachedJob)
		// Report the instances whose update failed so the caller can tell
		// which ones were not stopped.
		return &task.StopResponse{
			Error: &task.StopResponse_Error{
				UpdateError: &task.TaskUpdateError{
					Message: fmt.Sprintf("Goalstate update failed for %v", err),
				},
			},
			InvalidInstanceIds: instanceIds,
		}, nil
	}

	m.metrics.TaskStop.Inc(1)
	for _, instID := range instanceIds {
		m.goalStateDriver.EnqueueTask(jobID, instID, time.Now())
	}
	goalstate.EnqueueJobWithDefaultDelay(
		jobID, m.goalStateDriver, cachedJob)

	return &task.StopResponse{
		StoppedInstanceIds: instanceIds,
	}, nil
}