in sdks/go/pkg/beam/runners/prism/internal/execute.go [110:364]
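// executePipeline translates the pipeline graph into stages, registers them
// with an engine.ElementManager, and then executes the bundles the engine
// produces until the pipeline terminates or the context is canceled.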
func executePipeline(ctx context.Context, wks map[string]*worker.W, j *jobservices.Job) error {
pipeline := j.Pipeline
comps := proto.Clone(pipeline.GetComponents()).(*pipepb.Components)
// TODO: configure the preprocessor from pipeline options.
// Maybe change these returns to a single struct for convenience and further
// annotation?
handlers := []any{
Combine(CombineCharacteristic{EnableLifting: true}),
ParDo(ParDoCharacteristic{DisableSDF: true}),
Runner(RunnerCharacteristic{
SDKFlatten: false,
SDKReshuffle: false,
}),
}
proc := processor{
transformExecuters: map[string]transformExecuter{},
}
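// Sort the handlers into graph preparers and transform executers;
// a single handler may act as both.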
var preppers []transformPreparer
for _, h := range handlers {
if th, ok := h.(transformPreparer); ok {
preppers = append(preppers, th)
}
if th, ok := h.(transformExecuter); ok {
for _, urn := range th.ExecuteUrns() {
proc.transformExecuters[urn] = th
}
}
}
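// Run the graph rewrites, yielding fused stages in topological order.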
prepro := newPreprocessor(preppers)
topo := prepro.preProcessGraph(comps, j)
ts := comps.GetTransforms()
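// The ElementManager owns watermark tracking and element storage, and
// determines when a stage has a bundle ready to execute.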
em := engine.NewElementManager(engine.Config{})
// TODO: move this loop and code into the preprocessor instead.
stages := map[string]*stage{}
var impulses []string
for i, stage := range topo {
tid := stage.transforms[0]
t := ts[tid]
urn := t.GetSpec().GetUrn()
stage.exe = proc.transformExecuters[urn]
stage.ID = fmt.Sprintf("stage-%03d", i)
wk := wks[stage.envID]
switch stage.envID {
case "": // Runner Transforms
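// Runner transforms have at most one output; capture it for coder setup.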
var onlyOut string
for _, out := range t.GetOutputs() {
onlyOut = out
}
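// Record the coder information the engine needs to decode this
// PCollection's windowed values (and keys, for KV collections).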
stage.OutputsToCoders = map[string]engine.PColInfo{}
coders := map[string]*pipepb.Coder{}
makeWindowedValueCoder(onlyOut, comps, coders)
col := comps.GetPcollections()[onlyOut]
ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
winCoder, wDec, wEnc := getWindowValueCoders(comps, col, coders)
var kd func(io.Reader) []byte
if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok {
kd = collectionPullDecoder(kcid, coders, comps)
}
stage.OutputsToCoders[onlyOut] = engine.PColInfo{
GlobalID: onlyOut,
WindowCoder: winCoder,
WDec: wDec,
WEnc: wEnc,
EDec: ed,
KeyDec: kd,
}
// There are either 0, 1, or many inputs, but they should all be the same,
// so break after the first one.
for _, global := range t.GetInputs() {
col := comps.GetPcollections()[global]
ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
winCoder, wDec, wEnc := getWindowValueCoders(comps, col, coders)
stage.inputInfo = engine.PColInfo{
GlobalID: global,
WindowCoder: winCoder,
WDec: wDec,
WEnc: wEnc,
EDec: ed,
}
break
}
switch urn {
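// GBKs are aggregated by the engine itself: register the stage's
// input and output, and its windowing strategy (allowed lateness,
// accumulation mode, and trigger).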
case urns.TransformGBK:
em.AddStage(stage.ID, []string{getOnlyValue(t.GetInputs())}, []string{getOnlyValue(t.GetOutputs())}, nil)
for _, global := range t.GetInputs() {
col := comps.GetPcollections()[global]
ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
winCoder, wDec, wEnc := getWindowValueCoders(comps, col, coders)
var kd func(io.Reader) []byte
if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok {
kd = collectionPullDecoder(kcid, coders, comps)
}
stage.inputInfo = engine.PColInfo{
GlobalID: global,
WindowCoder: winCoder,
WDec: wDec,
WEnc: wEnc,
EDec: ed,
KeyDec: kd,
}
}
ws := windowingStrategy(comps, tid)
em.StageAggregates(stage.ID, engine.WinStrat{
AllowedLateness: time.Duration(ws.GetAllowedLateness()) * time.Millisecond,
Accumulating: pipepb.AccumulationMode_ACCUMULATING == ws.GetAccumulationMode(),
Trigger: buildTrigger(ws.GetTrigger()),
})
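// Impulses are recorded now and primed once every consumer is registered.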
case urns.TransformImpulse:
impulses = append(impulses, stage.ID)
em.AddStage(stage.ID, nil, []string{getOnlyValue(t.GetOutputs())}, nil)
case urns.TransformTestStream:
// Add a synthetic stage that should largely be unused.
em.AddStage(stage.ID, nil, maps.Values(t.GetOutputs()), nil)
// Decode the test stream, and convert it to the various events for the ElementManager.
var pyld pipepb.TestStreamPayload
if err := proto.Unmarshal(t.GetSpec().GetPayload(), &pyld); err != nil {
return fmt.Errorf("prism error building stage %v - decoding TestStreamPayload: \n%w", stage.ID, err)
}
// Ensure awareness of the coder used for the TestStream.
cID, err := lpUnknownCoders(pyld.GetCoderId(), coders, comps.GetCoders())
if err != nil {
return fmt.Errorf("prism error building stage %v - handling TestStream coder: \n%w", stage.ID, err)
}
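// mayLP optionally re-encodes an element; by default it's the identity.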
mayLP := func(v []byte) []byte {
return v
}
// Hack for Java strings in TestStream, since the Java SDK doesn't encode them correctly.
forceLP := cID == "StringUtf8Coder" || cID != pyld.GetCoderId()
if forceLP {
// The coder needed length prefixing. For simplicity, add a length prefix
// to each encoded element, since we will be sending a length-prefixed
// coder to consume this anyway. This is simpler than trying to find all
// the rewritten coders after the fact.
mayLP = func(v []byte) []byte {
var buf bytes.Buffer
if err := coder.EncodeVarInt(int64(len(v)), &buf); err != nil {
panic(err)
}
if _, err := buf.Write(v); err != nil {
panic(err)
}
return buf.Bytes()
}
}
tsb := em.AddTestStream(stage.ID, t.GetOutputs())
for _, e := range pyld.GetEvents() {
switch ev := e.GetEvent().(type) {
case *pipepb.TestStreamPayload_Event_ElementEvent:
var elms []engine.TestStreamElement
for _, e := range ev.ElementEvent.GetElements() {
elms = append(elms, engine.TestStreamElement{Encoded: mayLP(e.GetEncodedElement()), EventTime: mtime.FromMilliseconds(e.GetTimestamp())})
}
tsb.AddElementEvent(ev.ElementEvent.GetTag(), elms)
case *pipepb.TestStreamPayload_Event_WatermarkEvent:
tsb.AddWatermarkEvent(ev.WatermarkEvent.GetTag(), mtime.FromMilliseconds(ev.WatermarkEvent.GetNewWatermark()))
case *pipepb.TestStreamPayload_Event_ProcessingTimeEvent:
if ev.ProcessingTimeEvent.GetAdvanceDuration() == int64(mtime.MaxTimestamp) {
// TODO: Determine the SDK common formalism for setting processing time to infinity.
tsb.AddProcessingTimeEvent(time.Duration(mtime.MaxTimestamp))
} else {
tsb.AddProcessingTimeEvent(time.Duration(ev.ProcessingTimeEvent.GetAdvanceDuration()) * time.Millisecond)
}
default:
return fmt.Errorf("prism error building stage %v - unknown TestStream event type: %T", stage.ID, ev)
}
}
case urns.TransformFlatten:
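// Flattens only need their inputs registered; sort them for a
// deterministic stage definition, since map iteration order varies.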
inputs := maps.Values(t.GetInputs())
sort.Strings(inputs)
em.AddStage(stage.ID, inputs, []string{getOnlyValue(t.GetOutputs())}, nil)
}
stages[stage.ID] = stage
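// SDK transforms: build a bundle descriptor for the worker, and
// register the stage's inputs, outputs, state, and timers with the engine.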
case wk.Env:
if err := buildDescriptor(stage, comps, wk, em); err != nil {
return fmt.Errorf("prism error building stage %v: \n%w", stage.ID, err)
}
stages[stage.ID] = stage
j.Logger.Debug("pipelineBuild", slog.Group("stage", slog.String("ID", stage.ID), slog.String("transformName", t.GetUniqueName())))
outputs := maps.Keys(stage.OutputsToCoders)
sort.Strings(outputs)
em.AddStage(stage.ID, []string{stage.primaryInput}, outputs, stage.sideInputs)
if stage.stateful {
em.StageStateful(stage.ID, stage.stateTypeLen)
}
if stage.onWindowExpiration.TimerFamily != "" {
slog.Debug("OnWindowExpiration", slog.String("stage", stage.ID), slog.Any("values", stage.onWindowExpiration))
em.StageOnWindowExpiration(stage.ID, stage.onWindowExpiration)
}
if len(stage.processingTimeTimers) > 0 {
em.StageProcessingTimeTimers(stage.ID, stage.processingTimeTimers)
}
default:
return fmt.Errorf("unknown environment[%v]", t.GetEnvironmentId())
}
}
// Prime the initial impulses, since we now know what consumes them.
for _, id := range impulses {
em.Impulse(id)
}
// Use an errgroup to limit max parallelism for the pipeline.
eg, egctx := errgroup.WithContext(ctx)
eg.SetLimit(8)
var instID uint64
bundles := em.Bundles(egctx, j.CancelFn, func() string {
return fmt.Sprintf("inst%03d", atomic.AddUint64(&instID, 1))
})
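// Scheduling loop: execute bundles as the engine emits them. The channel
// closes once the pipeline has terminated, and the final error (if any)
// is collected from the errgroup.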
for {
select {
case <-ctx.Done():
err := context.Cause(ctx)
j.Logger.Debug("context canceled", slog.Any("cause", err))
return err
case rb, ok := <-bundles:
if !ok {
err := eg.Wait()
j.Logger.Debug("pipeline done!", slog.String("job", j.String()), slog.Any("error", err), slog.Any("topo", topo))
return err
}
eg.Go(func() error {
s := stages[rb.StageID]
wk := wks[s.envID]
if err := s.Execute(ctx, j, wk, comps, em, rb); err != nil {
// Ensure we clean up on bundle failure
em.FailBundle(rb)
return err
}
return nil
})
}
}
}