func (p *processor) prepareTasksForLaunch()

in pkg/jobmgr/task/placement/placement.go [180:372]


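// prepareTasksForLaunch resolves each placed Mesos task to its cached job
// and task, validates that the placement still applies to the current task
// run, assigns dynamic ports from the placement, and patches the task
// runtime with the host, agent ID, and LAUNCHED state. It returns the
// launchable tasks keyed by Peloton task ID, along with the task IDs that
// were skipped.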
func (p *processor) prepareTasksForLaunch(
	ctx context.Context,
	taskIDs []*mesos.TaskID,
	hostname string,
	agentID string,
	selectedPorts []uint32,
) (
	map[string]*lifecyclemgr.LaunchableTaskInfo,
	[]*peloton.TaskID,
	error,
) {
	portsIndex := 0
	taskInfos := make(map[string]*lifecyclemgr.LaunchableTaskInfo)
	skippedTaskIDs := make([]*peloton.TaskID, 0)

	for _, mtaskID := range taskIDs {
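		// The Mesos task ID encodes the Peloton job ID and instance ID.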
		id, instanceID, err := util.ParseJobAndInstanceID(mtaskID.GetValue())
		if err != nil {
			log.WithField("mesos_task_id", mtaskID.GetValue()).
				WithError(err).
				Error("Failed to parse mesos task id")
			continue
		}

		jobID := &peloton.JobID{Value: id}
		ptaskID := &peloton.TaskID{
			Value: util.CreatePelotonTaskID(id, instanceID),
		}
		ptaskIDStr := ptaskID.GetValue()

		cachedJob := p.jobFactory.GetJob(jobID)
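		// Skip the task if its job is no longer present in the job cache.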
		if cachedJob == nil {
			skippedTaskIDs = append(skippedTaskIDs, ptaskID)
			continue
		}

		cachedTask, err := cachedJob.AddTask(ctx, uint32(instanceID))
		if err != nil {
			log.WithError(err).
				WithFields(log.Fields{
					"job_id":      jobID.GetValue(),
					"instance_id": uint32(instanceID),
				}).Error("cannot add and recover task from DB")
			continue
		}

		cachedRuntime, err := cachedTask.GetRuntime(ctx)
		if err != nil {
			log.WithError(err).
				WithFields(log.Fields{
					"job_id":      jobID.GetValue(),
					"instance_id": uint32(instanceID),
				}).Error("cannot fetch task runtime")
			continue
		}

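		// The placement was computed for a specific task run. If the cached
		// runtime now points at a different Mesos task ID, this placement is
		// for an old run and must not be launched.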
		if cachedRuntime.GetMesosTaskId().GetValue() != mtaskID.GetValue() {
			log.WithFields(log.Fields{
				"job_id":        jobID.GetValue(),
				"instance_id":   uint32(instanceID),
				"mesos_task_id": mtaskID.GetValue(),
			}).Info("skipping launch of old run")
			skippedTaskIDs = append(skippedTaskIDs, ptaskID)
			continue
		}

		taskConfig, configAddOn, err := p.taskConfigV2Ops.GetTaskConfig(
			ctx,
			jobID,
			uint32(instanceID),
			cachedRuntime.GetConfigVersion())
		if err != nil {
			log.WithError(err).
				WithField("task_id", ptaskID.GetValue()).
				Error("not able to get task configuration")
			continue
		}

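		// When running against the v1 host manager, the launch path also
		// needs the pod spec, so fetch it alongside the task config.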
		var spec *pbpod.PodSpec
		if p.hmVersion.IsV1() {
			// TODO: unify this call with p.taskConfigV2Ops.GetTaskConfig().
			spec, err = p.taskConfigV2Ops.GetPodSpec(
				ctx,
				jobID,
				uint32(instanceID),
				cachedRuntime.GetConfigVersion())
			if err != nil {
				log.WithError(err).
					WithField("task_id", ptaskIDStr).
					Error("not able to get pod spec")
				continue
			}
		}

		runtimeDiff := make(jobmgrcommon.RuntimeDiff)
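		// Record the placement (host, agent ID, LAUNCHED state) only if the
		// task is not meant to be killed; killed tasks are skipped below.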
		if cachedRuntime.GetGoalState() != task.TaskState_KILLED {
			runtimeDiff[jobmgrcommon.HostField] = hostname
			runtimeDiff[jobmgrcommon.AgentIDField] = &mesos.AgentID{
				Value: &agentID,
			}
			runtimeDiff[jobmgrcommon.StateField] = task.TaskState_LAUNCHED
		}

		if selectedPorts != nil {
			// Reset runtime ports so the task picks up the new port
			// assignment from this placement.
			ports := make(map[string]uint32)
			// Assign a selected dynamic port to the task for each dynamic
			// port config.
			for _, portConfig := range taskConfig.GetPorts() {
				if portConfig.GetValue() != 0 {
					// Skip static port.
					continue
				}
				if portsIndex >= len(selectedPorts) {
					// This should never happen.
					log.WithFields(log.Fields{
						"selected_ports": selectedPorts,
						"task_id":        ptaskIDStr,
					}).Error("placement contains less selected ports than required.")
					return nil, nil, errors.New("invalid placement")
				}
				ports[portConfig.GetName()] = selectedPorts[portsIndex]
				portsIndex++
			}
			runtimeDiff[jobmgrcommon.PortsField] = ports
		}

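		// Annotate the runtime update with a message and reason for visibility.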
		runtimeDiff[jobmgrcommon.MessageField] = "Add hostname and ports"
		runtimeDiff[jobmgrcommon.ReasonField] = "REASON_UPDATE_OFFER"

		if cachedRuntime.GetGoalState() == task.TaskState_KILLED {
			if cachedRuntime.GetState() != task.TaskState_KILLED {
				// Received placement for task which needs to be killed,
				// retry killing the task.
				p.goalStateDriver.EnqueueTask(
					jobID,
					uint32(instanceID),
					time.Now(),
				)
			}
			// Skip launching of killed tasks.
			skippedTaskIDs = append(skippedTaskIDs, ptaskID)
			log.WithField("task_id", ptaskIDStr).
				Info("skipping launch of killed task")
			continue
		}

		// Patch task runtime with the generated runtime diff.
		retry := 0
		for retry < maxRetryCount {
			_, _, err := cachedJob.PatchTasks(
				ctx,
				map[uint32]jobmgrcommon.RuntimeDiff{
					uint32(instanceID): runtimeDiff,
				},
				false,
			)
			if err == nil {
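				// The patch succeeded; read the updated runtime back from
				// the cache and record the task as ready to launch.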
				runtime, err := cachedTask.GetRuntime(ctx)
				if err != nil {
					log.WithError(err).
						WithField("task_id", ptaskIDStr).
						Error("cannot fetch runtime after patching task")
					break
				}
				taskInfos[ptaskIDStr] = &lifecyclemgr.LaunchableTaskInfo{
					TaskInfo: &task.TaskInfo{
						Runtime:    runtime,
						Config:     taskConfig,
						InstanceId: uint32(instanceID),
						JobId:      jobID,
					},
					ConfigAddOn: configAddOn,
					Spec:        spec,
				}
				break
			}

			if common.IsTransientError(err) {
				// Retry the patch on transient errors; the loop bounds
				// retries at maxRetryCount.
				log.WithError(err).WithFields(log.Fields{
					"job_id":      jobID,
					"instance_id": instanceID,
				}).Warn("retrying update task runtime on transient error")
				retry++
				continue
			}

			log.WithError(err).
				WithFields(log.Fields{
					"job_id":      jobID,
					"instance_id": instanceID,
				}).Error("cannot process placement due to non-transient db error")
			delete(taskInfos, ptaskIDStr)
			break
		}
	}

	return taskInfos, skippedTaskIDs, nil
}