in sdks/go/pkg/beam/runners/prism/internal/execute.go [116:392]
// executePipeline builds the execution stages for the job's pipeline, registers
// them with an engine.ElementManager, and then executes the bundles the manager
// produces until the pipeline finishes, fails, or ctx is done.
//
// wks maps environment IDs to the worker handling that environment. Stages with
// an empty environment ID are handled as runner transforms and don't need a
// worker. A stage whose environment ID has no matching worker results in an
// "unknown environment" error.
//
// The returned error is the cause of ctx's cancellation, an error from building
// a stage, or the result of waiting on the bundle execution group once the
// bundle channel has been closed.
func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservices.Job) error {
	pipeline := j.Pipeline
	// Work on a clone of the components rather than the job's own copy.
	comps := proto.Clone(pipeline.GetComponents()).(*pipepb.Components)

	// TODO, configure the preprocessor from pipeline options.
	// Maybe change these returns to a single struct for convenience and further
	// annotation?

	handlers := []any{
		Combine(CombineCharacteristic{EnableLifting: true}),
		ParDo(ParDoCharacteristic{DisableSDF: true}),
		Runner(RunnerCharacteristic{
			SDKFlatten:   false,
			SDKReshuffle: false,
		}),
	}

	proc := processor{
		transformExecuters: map[string]transformExecuter{},
	}

	// A handler may prepare transforms, execute transforms, or both, so each
	// one is checked against both interfaces.
	var preppers []transformPreparer
	for _, h := range handlers {
		if th, ok := h.(transformPreparer); ok {
			preppers = append(preppers, th)
		}
		if th, ok := h.(transformExecuter); ok {
			for _, urn := range th.ExecuteUrns() {
				proc.transformExecuters[urn] = th
			}
		}
	}

	prepro := newPreprocessor(preppers)

	topo := prepro.preProcessGraph(comps, j)
	ts := comps.GetTransforms()
	pcols := comps.GetPcollections()

	// Both features default to enabled, and can be turned off through
	// experiments in the pipeline options.
	config := engine.Config{EnableRTC: true, EnableSDFSplit: true}
	m := j.PipelineOptions().AsMap()
	if experimentsSlice, ok := m["beam:option:experiments:v1"].([]any); ok {
		// A single pass covers every experiment. Entries that aren't strings
		// are ignored.
		for _, exp := range experimentsSlice {
			expStr, ok := exp.(string)
			if !ok {
				continue
			}
			switch expStr {
			case "prism_disable_rtc":
				config.EnableRTC = false
			case "prism_disable_sdf_split":
				config.EnableSDFSplit = false
			}
		}
	}

	if streaming, ok := m["beam:option:streaming:v1"].(bool); ok {
		config.StreamingMode = streaming
	}

	// Set StreamingMode to true if there is any unbounded PCollection.
	for _, pcoll := range pcols {
		if pcoll.GetIsBounded() == pipepb.IsBounded_UNBOUNDED {
			config.StreamingMode = true
			break
		}
	}

	em := engine.NewElementManager(config)

	// TODO move this loop and code into the preprocessor instead.
	stages := map[string]*stage{}
	var impulses []string

	for i, stage := range topo {
		// The first transform of the stage selects the executer for the stage.
		tid := stage.transforms[0]
		t := ts[tid]
		urn := t.GetSpec().GetUrn()
		stage.exe = proc.transformExecuters[urn]

		stage.ID = fmt.Sprintf("stage-%03d", i)
		// wk is nil when no worker is registered for the stage's environment.
		wk := wks[stage.envID]

		switch {
		case stage.envID == "": // Runner Transforms
			// Takes whichever output the map iteration visits last; with a
			// single output that's the only one.
			var onlyOut string
			for _, out := range t.GetOutputs() {
				onlyOut = out
			}
			stage.OutputsToCoders = map[string]engine.PColInfo{}
			coders := map[string]*pipepb.Coder{}
			makeWindowedValueCoder(onlyOut, comps, coders)

			col := comps.GetPcollections()[onlyOut]
			ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
			winCoder, wDec, wEnc := getWindowValueCoders(comps, col, coders)

			// The key decoder is only set when a KV coder ID can be extracted
			// for the collection.
			var kd func(io.Reader) []byte
			if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok {
				kd = collectionPullDecoder(kcid, coders, comps)
			}
			stage.OutputsToCoders[onlyOut] = engine.PColInfo{
				GlobalID:    onlyOut,
				WindowCoder: winCoder,
				WDec:        wDec,
				WEnc:        wEnc,
				EDec:        ed,
				KeyDec:      kd,
			}

			// There's either 0, 1 or many inputs, but they should be all the same
			// so break after the first one.
			for _, global := range t.GetInputs() {
				col := comps.GetPcollections()[global]
				ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
				winCoder, wDec, wEnc := getWindowValueCoders(comps, col, coders)
				stage.inputInfo = engine.PColInfo{
					GlobalID:    global,
					WindowCoder: winCoder,
					WDec:        wDec,
					WEnc:        wEnc,
					EDec:        ed,
				}
				break
			}

			switch urn {
			case urns.TransformGBK:
				em.AddStage(stage.ID, []string{getOnlyValue(t.GetInputs())}, []string{getOnlyValue(t.GetOutputs())}, nil)
				// Replace the input info set above with one that also carries
				// a key decoder, when one is available.
				for _, global := range t.GetInputs() {
					col := comps.GetPcollections()[global]
					ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
					winCoder, wDec, wEnc := getWindowValueCoders(comps, col, coders)

					var kd func(io.Reader) []byte
					if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok {
						kd = collectionPullDecoder(kcid, coders, comps)
					}
					stage.inputInfo = engine.PColInfo{
						GlobalID:    global,
						WindowCoder: winCoder,
						WDec:        wDec,
						WEnc:        wEnc,
						EDec:        ed,
						KeyDec:      kd,
					}
				}
				ws := windowingStrategy(comps, tid)
				em.StageAggregates(stage.ID, engine.WinStrat{
					AllowedLateness: time.Duration(ws.GetAllowedLateness()) * time.Millisecond,
					Accumulating:    pipepb.AccumulationMode_ACCUMULATING == ws.GetAccumulationMode(),
					Trigger:         buildTrigger(ws.GetTrigger()),
				})
			case urns.TransformImpulse:
				// Impulses are primed after all stages are registered.
				impulses = append(impulses, stage.ID)
				em.AddStage(stage.ID, nil, []string{getOnlyValue(t.GetOutputs())}, nil)
			case urns.TransformTestStream:
				// Add a synthetic stage that should largely be unused.
				em.AddStage(stage.ID, nil, maps.Values(t.GetOutputs()), nil)
				for pcolID, info := range stage.OutputsToCoders {
					em.RegisterPColInfo(pcolID, info)
				}

				// Decode the test stream, and convert it to the various events for the ElementManager.
				var pyld pipepb.TestStreamPayload
				if err := proto.Unmarshal(t.GetSpec().GetPayload(), &pyld); err != nil {
					return fmt.Errorf("prism error building stage %v - decoding TestStreamPayload: \n%w", stage.ID, err)
				}

				tsb := em.AddTestStream(stage.ID, t.Outputs)
				for _, e := range pyld.GetEvents() {
					switch ev := e.GetEvent().(type) {
					case *pipepb.TestStreamPayload_Event_ElementEvent:
						var elms []engine.TestStreamElement
						for _, e := range ev.ElementEvent.GetElements() {
							// Encoded bytes are already handled in handleTestStream if needed.
							elms = append(elms, engine.TestStreamElement{Encoded: e.GetEncodedElement(), EventTime: mtime.FromMilliseconds(e.GetTimestamp())})
						}
						tsb.AddElementEvent(ev.ElementEvent.GetTag(), elms)
					case *pipepb.TestStreamPayload_Event_WatermarkEvent:
						tsb.AddWatermarkEvent(ev.WatermarkEvent.GetTag(), mtime.FromMilliseconds(ev.WatermarkEvent.GetNewWatermark()))
					case *pipepb.TestStreamPayload_Event_ProcessingTimeEvent:
						// An advance equal to the max timestamp is passed through
						// unscaled; any other advance is treated as milliseconds.
						if ev.ProcessingTimeEvent.GetAdvanceDuration() == int64(mtime.MaxTimestamp) {
							// TODO: Determine the SDK common formalism for setting processing time to infinity.
							tsb.AddProcessingTimeEvent(time.Duration(mtime.MaxTimestamp))
						} else {
							tsb.AddProcessingTimeEvent(time.Duration(ev.ProcessingTimeEvent.GetAdvanceDuration()) * time.Millisecond)
						}
					default:
						return fmt.Errorf("prism error building stage %v - unknown TestStream event type: %T", stage.ID, ev)
					}
				}

			case urns.TransformFlatten:
				// Sort the inputs since map iteration order is random.
				inputs := maps.Values(t.GetInputs())
				sort.Strings(inputs)
				em.AddStage(stage.ID, inputs, []string{getOnlyValue(t.GetOutputs())}, nil)
			}
			stages[stage.ID] = stage
		case wk != nil && wk.Env == stage.envID:
			// The nil check sends a stage without a registered worker to the
			// default case below, instead of panicking on wk.Env.
			if err := buildDescriptor(stage, comps, wk, em); err != nil {
				return fmt.Errorf("prism error building stage %v: \n%w", stage.ID, err)
			}
			stages[stage.ID] = stage
			// Sort the outputs since map iteration order is random.
			outputs := maps.Keys(stage.OutputsToCoders)
			sort.Strings(outputs)
			em.AddStage(stage.ID, []string{stage.primaryInput}, outputs, stage.sideInputs)
			if stage.stateful {
				em.StageStateful(stage.ID, stage.stateTypeLen)
			}
			if stage.onWindowExpiration.TimerFamily != "" {
				slog.Debug("OnWindowExpiration", slog.String("stage", stage.ID), slog.Any("values", stage.onWindowExpiration))
				em.StageOnWindowExpiration(stage.ID, stage.onWindowExpiration)
			}
			if len(stage.processingTimeTimers) > 0 {
				em.StageProcessingTimeTimers(stage.ID, stage.processingTimeTimers)
			}
			stage.sdfSplittable = config.EnableSDFSplit
		default:
			return fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId())
		}
	}

	// Prime the initial impulses, since we now know what consumes them.
	for _, id := range impulses {
		em.Impulse(id)
	}

	// Use an errgroup to limit max parallelism for the pipeline.
	eg, egctx := errgroup.WithContext(ctx)
	eg.SetLimit(8)

	// instID is incremented atomically to produce the generated IDs, since the
	// generator function is handed off to the element manager.
	var instID uint64
	bundles := em.Bundles(egctx, j.CancelFn, func() string {
		return fmt.Sprintf("inst%03d", atomic.AddUint64(&instID, 1))
	})

	// Create a new ticker that fires every 60 seconds.
	ticker := time.NewTicker(60 * time.Second)
	// Ensure the ticker is stopped when the function returns to prevent a goroutine leak.
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			err := context.Cause(ctx)
			j.Logger.Debug("context canceled", slog.Any("cause", err))
			return err
		case rb, ok := <-bundles:
			if !ok {
				// The bundle channel is closed, so wait for the in-flight
				// bundles and report their result.
				err := eg.Wait()
				j.Logger.Info("pipeline done!", slog.String("job", j.String()))
				j.Logger.Debug("finished state", slog.String("job", j.String()), slog.Any("error", err), slog.String("stages", em.DumpStages()))
				return err
			}
			eg.Go(func() error {
				s := stages[rb.StageID]
				wk := wks[s.envID]
				// NOTE(review): bundles execute with ctx rather than egctx, so a
				// failing bundle doesn't cancel its in-flight siblings through
				// the context — confirm this is intended.
				if err := s.Execute(ctx, j, wk, comps, em, rb); err != nil {
					// Ensure we clean up on bundle failure
					j.Logger.Error("Bundle Failed.", slog.Any("error", err))
					em.FailBundle(rb)
					return err
				}
				return nil
			})
		// Log a heartbeat every 60 seconds
		case <-ticker.C:
			j.Logger.Info("pipeline is running", slog.String("job", j.String()))
		}
	}
}