func executePipeline()

in sdks/go/pkg/beam/runners/prism/internal/execute.go [116:392]


// executePipeline translates the job's pipeline into prism stages, registers
// them with a new ElementManager, and executes bundles until the
// ElementManager closes its bundle channel or ctx is done.
//
// Steps:
//   - Configure the preprocessor and the runner's transform executors from a
//     fixed set of Combine/ParDo/Runner handlers.
//   - Build an engine.Config from pipeline options: the experiments
//     "prism_disable_rtc" and "prism_disable_sdf_split" turn off RTC and SDF
//     splitting, "beam:option:streaming:v1" sets streaming mode, and any
//     unbounded PCollection forces streaming mode on.
//   - Register every stage with the ElementManager. Stages without an
//     environment are runner transforms (GBK, Impulse, TestStream, Flatten);
//     stages whose environment has a worker in wks get an SDK bundle
//     descriptor.
//   - Prime the impulses, then execute bundles with at most 8 in flight.
//
// It returns an error when a stage cannot be built or its environment has no
// worker in wks, the first bundle execution error reported by the errgroup,
// or the cause of ctx's cancellation.
func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservices.Job) error {
	pipeline := j.Pipeline
	comps := proto.Clone(pipeline.GetComponents()).(*pipepb.Components)

	// TODO, configure the preprocessor from pipeline options.
	// Maybe change these returns to a single struct for convenience and further
	// annotation?

	handlers := []any{
		Combine(CombineCharacteristic{EnableLifting: true}),
		ParDo(ParDoCharacteristic{DisableSDF: true}),
		Runner(RunnerCharacteristic{
			SDKFlatten:   false,
			SDKReshuffle: false,
		}),
	}

	proc := processor{
		transformExecuters: map[string]transformExecuter{},
	}

	// A handler may prepare transforms during preprocessing, execute some URNs
	// directly in the runner, or both.
	var preppers []transformPreparer
	for _, h := range handlers {
		if th, ok := h.(transformPreparer); ok {
			preppers = append(preppers, th)
		}
		if th, ok := h.(transformExecuter); ok {
			for _, urn := range th.ExecuteUrns() {
				proc.transformExecuters[urn] = th
			}
		}
	}

	prepro := newPreprocessor(preppers)

	topo := prepro.preProcessGraph(comps, j)
	ts := comps.GetTransforms()
	pcols := comps.GetPcollections()

	config := engine.Config{EnableRTC: true, EnableSDFSplit: true}
	m := j.PipelineOptions().AsMap()
	// Experiments can disable individual engine features; one pass over the
	// list handles every recognized experiment.
	if experiments, ok := m["beam:option:experiments:v1"].([]any); ok {
		for _, exp := range experiments {
			switch name, _ := exp.(string); name {
			case "prism_disable_rtc":
				config.EnableRTC = false
			case "prism_disable_sdf_split":
				config.EnableSDFSplit = false
			}
		}
	}

	if streaming, ok := m["beam:option:streaming:v1"].(bool); ok {
		config.StreamingMode = streaming
	}

	// Set StreamingMode to true if there is any unbounded PCollection.
	for _, pcoll := range pcols {
		if pcoll.GetIsBounded() == pipepb.IsBounded_UNBOUNDED {
			config.StreamingMode = true
			break
		}
	}

	em := engine.NewElementManager(config)

	// TODO move this loop and code into the preprocessor instead.
	stages := map[string]*stage{}
	var impulses []string

	for i, stage := range topo {
		tid := stage.transforms[0]
		t := ts[tid]
		urn := t.GetSpec().GetUrn()
		stage.exe = proc.transformExecuters[urn]

		stage.ID = fmt.Sprintf("stage-%03d", i)
		// wk is nil when no worker exists for this environment. The switch
		// below must check that before reading wk.Env, otherwise an unknown
		// environment panics instead of reaching the error in the default case.
		wk := wks[stage.envID]

		switch {
		case stage.envID == "": // Runner Transforms
			var onlyOut string
			for _, out := range t.GetOutputs() {
				onlyOut = out
			}
			stage.OutputsToCoders = map[string]engine.PColInfo{}
			coders := map[string]*pipepb.Coder{}
			makeWindowedValueCoder(onlyOut, comps, coders)

			col := comps.GetPcollections()[onlyOut]
			ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
			winCoder, wDec, wEnc := getWindowValueCoders(comps, col, coders)

			var kd func(io.Reader) []byte
			if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok {
				kd = collectionPullDecoder(kcid, coders, comps)
			}
			stage.OutputsToCoders[onlyOut] = engine.PColInfo{
				GlobalID:    onlyOut,
				WindowCoder: winCoder,
				WDec:        wDec,
				WEnc:        wEnc,
				EDec:        ed,
				KeyDec:      kd,
			}

			// There's either 0, 1 or many inputs, but they should be all the same
			// so break after the first one.
			for _, global := range t.GetInputs() {
				col := comps.GetPcollections()[global]
				ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
				winCoder, wDec, wEnc := getWindowValueCoders(comps, col, coders)
				stage.inputInfo = engine.PColInfo{
					GlobalID:    global,
					WindowCoder: winCoder,
					WDec:        wDec,
					WEnc:        wEnc,
					EDec:        ed,
				}
				break
			}

			switch urn {
			case urns.TransformGBK:
				em.AddStage(stage.ID, []string{getOnlyValue(t.GetInputs())}, []string{getOnlyValue(t.GetOutputs())}, nil)
				// GBK needs the input's key decoder too, so inputInfo is rebuilt
				// here with KeyDec populated.
				for _, global := range t.GetInputs() {
					col := comps.GetPcollections()[global]
					ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
					winCoder, wDec, wEnc := getWindowValueCoders(comps, col, coders)

					var kd func(io.Reader) []byte
					if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok {
						kd = collectionPullDecoder(kcid, coders, comps)
					}
					stage.inputInfo = engine.PColInfo{
						GlobalID:    global,
						WindowCoder: winCoder,
						WDec:        wDec,
						WEnc:        wEnc,
						EDec:        ed,
						KeyDec:      kd,
					}
				}
				ws := windowingStrategy(comps, tid)
				em.StageAggregates(stage.ID, engine.WinStrat{
					AllowedLateness: time.Duration(ws.GetAllowedLateness()) * time.Millisecond,
					Accumulating:    pipepb.AccumulationMode_ACCUMULATING == ws.GetAccumulationMode(),
					Trigger:         buildTrigger(ws.GetTrigger()),
				})
			case urns.TransformImpulse:
				impulses = append(impulses, stage.ID)
				em.AddStage(stage.ID, nil, []string{getOnlyValue(t.GetOutputs())}, nil)
			case urns.TransformTestStream:
				// Add a synthetic stage that should largely be unused.
				em.AddStage(stage.ID, nil, maps.Values(t.GetOutputs()), nil)

				for pcolID, info := range stage.OutputsToCoders {
					em.RegisterPColInfo(pcolID, info)
				}

				// Decode the test stream, and convert it to the various events for the ElementManager.
				var pyld pipepb.TestStreamPayload
				if err := proto.Unmarshal(t.GetSpec().GetPayload(), &pyld); err != nil {
					return fmt.Errorf("prism error building stage %v - decoding TestStreamPayload: \n%w", stage.ID, err)
				}

				tsb := em.AddTestStream(stage.ID, t.Outputs)
				for _, e := range pyld.GetEvents() {
					switch ev := e.GetEvent().(type) {
					case *pipepb.TestStreamPayload_Event_ElementEvent:
						var elms []engine.TestStreamElement
						for _, e := range ev.ElementEvent.GetElements() {
							// Encoded bytes are already handled in handleTestStream if needed.
							elms = append(elms, engine.TestStreamElement{Encoded: e.GetEncodedElement(), EventTime: mtime.FromMilliseconds(e.GetTimestamp())})
						}
						tsb.AddElementEvent(ev.ElementEvent.GetTag(), elms)
					case *pipepb.TestStreamPayload_Event_WatermarkEvent:
						tsb.AddWatermarkEvent(ev.WatermarkEvent.GetTag(), mtime.FromMilliseconds(ev.WatermarkEvent.GetNewWatermark()))
					case *pipepb.TestStreamPayload_Event_ProcessingTimeEvent:
						if ev.ProcessingTimeEvent.GetAdvanceDuration() == int64(mtime.MaxTimestamp) {
							// TODO: Determine the SDK common formalism for setting processing time to infinity.
							tsb.AddProcessingTimeEvent(time.Duration(mtime.MaxTimestamp))
						} else {
							tsb.AddProcessingTimeEvent(time.Duration(ev.ProcessingTimeEvent.GetAdvanceDuration()) * time.Millisecond)
						}

					default:
						return fmt.Errorf("prism error building stage %v - unknown TestStream event type: %T", stage.ID, ev)
					}
				}

			case urns.TransformFlatten:
				// Sort inputs so stage registration does not depend on map iteration order.
				inputs := maps.Values(t.GetInputs())
				sort.Strings(inputs)
				em.AddStage(stage.ID, inputs, []string{getOnlyValue(t.GetOutputs())}, nil)
			}
			stages[stage.ID] = stage
		case wk != nil && stage.envID == wk.Env:
			if err := buildDescriptor(stage, comps, wk, em); err != nil {
				return fmt.Errorf("prism error building stage %v: \n%w", stage.ID, err)
			}
			stages[stage.ID] = stage
			outputs := maps.Keys(stage.OutputsToCoders)
			sort.Strings(outputs)
			em.AddStage(stage.ID, []string{stage.primaryInput}, outputs, stage.sideInputs)
			if stage.stateful {
				em.StageStateful(stage.ID, stage.stateTypeLen)
			}
			if stage.onWindowExpiration.TimerFamily != "" {
				slog.Debug("OnWindowExpiration", slog.String("stage", stage.ID), slog.Any("values", stage.onWindowExpiration))
				em.StageOnWindowExpiration(stage.ID, stage.onWindowExpiration)
			}
			if len(stage.processingTimeTimers) > 0 {
				em.StageProcessingTimeTimers(stage.ID, stage.processingTimeTimers)
			}
			stage.sdfSplittable = config.EnableSDFSplit
		default:
			return fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId())
		}
	}

	// Prime the initial impulses, since we now know what consumes them.
	for _, id := range impulses {
		em.Impulse(id)
	}

	// Use an errgroup to limit max parallelism for the pipeline.
	eg, egctx := errgroup.WithContext(ctx)
	eg.SetLimit(8)

	var instID uint64
	bundles := em.Bundles(egctx, j.CancelFn, func() string {
		return fmt.Sprintf("inst%03d", atomic.AddUint64(&instID, 1))
	})

	// Create a new ticker that fires every 60 seconds.
	ticker := time.NewTicker(60 * time.Second)
	// Ensure the ticker is stopped when the function returns to prevent a goroutine leak.
	defer ticker.Stop()
	for {
		select {
		case <-ctx.Done():
			err := context.Cause(ctx)
			j.Logger.Debug("context canceled", slog.Any("cause", err))
			return err
		case rb, ok := <-bundles:
			if !ok {
				err := eg.Wait()
				j.Logger.Info("pipeline done!", slog.String("job", j.String()))
				j.Logger.Debug("finished state", slog.String("job", j.String()), slog.Any("error", err), slog.String("stages", em.DumpStages()))
				return err
			}
			// eg.Go blocks while 8 bundles are in flight, so this loop (including
			// the ctx.Done and heartbeat cases) waits until a slot frees up.
			eg.Go(func() error {
				s := stages[rb.StageID]
				wk := wks[s.envID]
				// NOTE(review): bundles run with ctx rather than egctx, so the
				// errgroup's cancellation on the first error does not reach
				// in-flight sibling bundles through their context — confirm intended.
				if err := s.Execute(ctx, j, wk, comps, em, rb); err != nil {
					// Ensure we clean up on bundle failure
					j.Logger.Error("Bundle Failed.", slog.Any("error", err))
					em.FailBundle(rb)
					return err
				}
				return nil
			})
		// Log a heartbeat every 60 seconds
		case <-ticker.C:
			j.Logger.Info("pipeline is running", slog.String("job", j.String()))
		}
	}
}