func (m *serviceHandler) Start()

in pkg/jobmgr/tasksvc/handler.go [484:704]


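// Start moves a job back to PENDING and restarts the requested task instances
// whose goal state is KILLED, then enqueues the affected tasks and the job to
// the goal state engine.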
func (m *serviceHandler) Start(
	ctx context.Context,
	body *task.StartRequest) (resp *task.StartResponse, err error) {
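	// Log the outcome of the request on the way out: a warning with the
	// request, headers, and errors on failure, an info entry on success.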
	defer func() {
		headers := yarpcutil.GetHeaders(ctx)
		if err != nil || resp.GetError() != nil {
			entry := log.WithField("request", body).
				WithField("headers", headers)

			if err != nil {
				entry = entry.WithError(err)
			}
			if resp.GetError() != nil {
				entry = entry.WithField("start_error", resp.GetError().String())
			}

			entry.Warn("TaskManager.Start failed")
			return
		}

		log.WithField("request", body).
			WithField("headers", headers).
			Info("TaskManager.Start succeeded")
	}()

	m.metrics.TaskAPIStart.Inc(1)
	ctx, cancelFunc := context.WithTimeout(
		ctx,
		_rpcTimeout,
	)
	defer cancelFunc()

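	// Only the leader job manager instance serves the Start API.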
	if !m.candidate.IsLeader() {
		m.metrics.TaskStartFail.Inc(1)
		return nil, yarpcerrors.UnavailableErrorf("Task Start API not supported on non-leader")
	}

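	// Resolve the job in the cache and read its config; a config read failure
	// is reported to the caller as JobNotFound.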
	cachedJob := m.jobFactory.AddJob(body.JobId)
	cachedConfig, err := cachedJob.GetConfig(ctx)

	if err != nil {
		log.WithField("job_id", body.JobId.Value).
			WithError(err).
			Error("Failed to get job config")
		m.metrics.TaskStartFail.Inc(1)
		return &task.StartResponse{
			Error: &task.StartResponse_Error{
				NotFound: &pb_errors.JobNotFound{
					Id:      body.JobId,
					Message: err.Error(),
				},
			},
		}, nil
	}

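	// Move the job runtime to PENDING with compare-and-set, retrying on
	// concurrent runtime updates up to MaxConcurrencyErrorRetry times.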
	count := 0
	for {
		jobRuntime, err := cachedJob.GetRuntime(ctx)
		if err != nil {
			log.WithField("job_id", body.JobId.Value).
				WithError(err).
				Info("failed to fetch job runtime while starting tasks")
			m.metrics.TaskStartFail.Inc(1)
			return nil, err
		}

		// batch jobs in a terminal state cannot be restarted
		if cachedConfig.GetType() == pb_job.JobType_BATCH {
			if util.IsPelotonJobStateTerminal(jobRuntime.GetState()) {
				log.WithFields(log.Fields{
					"job_id": body.JobId.Value,
					"state":  jobRuntime.GetState().String(),
				}).Info("cannot start tasks in a terminal job")
				m.metrics.TaskStartFail.Inc(1)
				return nil, yarpcerrors.InvalidArgumentErrorf(
					"cannot start tasks in a terminated job")
			}
		}

		jobRuntime.State = pb_job.JobState_PENDING
		jobRuntime.GoalState = goalstateutil.GetDefaultJobGoalState(
			cachedConfig.GetType())

		// update the job runtime
		_, err = cachedJob.CompareAndSetRuntime(ctx, jobRuntime)
		if err == jobmgrcommon.UnexpectedVersionError {
			// concurrency error; retry up to MaxConcurrencyErrorRetry times
			count++
			if count < jobmgrcommon.MaxConcurrencyErrorRetry {
				continue
			}
		}

		if err != nil {
			log.WithField("job", body.JobId).
				WithError(err).
				Error("failed to set job runtime in db")
			m.metrics.TaskStartFail.Inc(1)
			return &task.StartResponse{
				Error: &task.StartResponse_Error{
					Failure: &task.TaskStartFailure{
						Message: fmt.Sprintf("task start failed while updating job status %v", err),
					},
				},
			}, nil
		}

		// job runtime is successfully updated, move on
		break
	}

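	// Fetch the task infos for the requested instance ranges from the DB.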
	taskInfos, err := m.getTaskInfosByRangesFromDB(
		ctx, body.GetJobId(), body.GetRanges())
	if err != nil {
		log.WithField("job", body.JobId).
			WithError(err).
			Error("failed to get tasks for job in db")
		m.metrics.TaskStartFail.Inc(1)
		return &task.StartResponse{
			Error: &task.StartResponse_Error{
				OutOfRange: &task.InstanceIdOutOfRange{
					JobId:         body.JobId,
					InstanceCount: cachedConfig.GetInstanceCount(),
				},
			},
		}, nil
	}

	var startedInstanceIds []uint32
	var failedInstanceIds []uint32

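	// For each task in the ranges, flip a KILLED goal state back to the
	// default goal state; record which instances succeeded and which failed.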
	for _, taskInfo := range taskInfos {
		cachedTask, err := cachedJob.AddTask(ctx, taskInfo.GetInstanceId())
		if err != nil {
			log.WithFields(log.Fields{
				"job_id":      body.GetJobId().GetValue(),
				"instance_id": taskInfo.InstanceId,
			}).Info("failed to add task during task start")
			failedInstanceIds = append(failedInstanceIds, taskInfo.InstanceId)
			continue
		}

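		// Retry the per-task compare-and-set on version conflicts, mirroring
		// the job runtime update above.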
		count := 0
		for {
			taskRuntime, err := cachedTask.GetRuntime(ctx)
			if err != nil {
				log.WithFields(log.Fields{
					"job_id":      body.GetJobId().GetValue(),
					"instance_id": taskInfo.InstanceId,
				}).Info("failed to fetch runtime during task start")
				failedInstanceIds = append(failedInstanceIds, taskInfo.InstanceId)
				break
			}

			if taskRuntime.GetGoalState() != task.TaskState_KILLED {
				// ignore start request for tasks with non-killed goal state
				log.WithFields(log.Fields{
					"instance_id": taskInfo.InstanceId,
					"job_id":      body.GetJobId().GetValue(),
					"goal_state":  taskRuntime.GetGoalState().String(),
				}).Debug("task was not stopped")
				break
			}

			// Regenerate the mesos task runtime and reset the goal state
			healthState := taskutil.GetInitialHealthState(taskInfo.GetConfig())
			taskutil.RegenerateMesosTaskRuntime(
				body.GetJobId(),
				taskInfo.InstanceId,
				taskRuntime,
				healthState,
			)
			taskRuntime.GoalState =
				jobmgr_task.GetDefaultTaskGoalState(cachedConfig.GetType())
			taskRuntime.Message = "Task start API request"

			// Call the task-level compare-and-set API directly instead of a
			// job-level call that batches writes into one transaction (such
			// as PatchTasks): compare-and-set writes cannot be batched,
			// because a version conflict on a single task's runtime should
			// not fail and retry the whole transaction.
			_, err = cachedJob.CompareAndSetTask(
				ctx,
				taskInfo.GetInstanceId(),
				taskRuntime,
				false,
			)

			if err == jobmgrcommon.UnexpectedVersionError {
				count++
				if count < jobmgrcommon.MaxConcurrencyErrorRetry {
					continue
				}
			}

			if err != nil {
				log.WithError(err).
					WithFields(log.Fields{
						"job_id":      body.GetJobId().GetValue(),
						"instance_id": taskInfo.InstanceId,
					}).Info("failed to write runtime during task start")
				failedInstanceIds = append(failedInstanceIds, taskInfo.InstanceId)
			} else {
				startedInstanceIds = append(startedInstanceIds, taskInfo.InstanceId)
			}
			break
		}
	}

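	// Enqueue every successfully updated task, and then the job, to the goal
	// state engine so the restarts are acted upon.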
	for _, instID := range startedInstanceIds {
		m.goalStateDriver.EnqueueTask(body.GetJobId(), instID, time.Now())
	}
	goalstate.EnqueueJobWithDefaultDelay(
		body.GetJobId(), m.goalStateDriver, cachedJob)

	m.metrics.TaskStart.Inc(1)
	return &task.StartResponse{
		StartedInstanceIds: startedInstanceIds,
		InvalidInstanceIds: failedInstanceIds,
	}, nil
}
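
Both the job runtime update and the per-task update above share the same fetch-mutate-compare-and-set retry shape. The sketch below pulls that shape into a standalone helper purely for illustration; retryOnVersionConflict and its do callback are hypothetical names, imports are omitted to match the excerpt above, and the jobmgrcommon sentinels are assumed to behave exactly as they do in Start.

// retryOnVersionConflict is a hypothetical helper that illustrates the retry
// loop used twice in Start. The do callback is expected to read the latest
// runtime, apply the mutation, and attempt the compare-and-set write.
func retryOnVersionConflict(
	ctx context.Context,
	maxAttempts int,
	do func(ctx context.Context) error,
) error {
	var err error
	for attempt := 0; attempt < maxAttempts; attempt++ {
		err = do(ctx)
		if err != jobmgrcommon.UnexpectedVersionError {
			// The write succeeded, or it failed for a reason other than a
			// concurrent runtime update; either way, stop retrying.
			return err
		}
	}
	// Every attempt hit a version conflict; surface the last error.
	return err
}

In Start itself the loop stays inlined, since the per-task variant also needs to break out early when a task's goal state is not KILLED.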