in sdks/go/pkg/beam/runners/prism/internal/stage.go [386:664]
// buildDescriptor assembles the fnpb.ProcessBundleDescriptor for stg out of the
// pipeline components, stores it in stg.desc and registers it in
// wk.Descriptors under stg.ID.
//
// The descriptor is built in the following order:
//   - the stage's transforms are added; ParDo payloads with state or timer
//     family specs are re-encoded so their coder IDs are the ones returned by
//     lpUnknownCoders,
//   - one sink transform is added per stage output,
//   - side inputs are prepared through handleSideInput,
//   - a source transform is added for the primary input,
//   - internal PCollections and windowing strategy coders are added.
//
// Besides stg.desc, the following stage fields are written: stateTypeLen,
// hasTimers, processingTimeTimers, inputTransformID, prepareSides,
// SinkToPCollection, OutputsToCoders and inputInfo. Some of these are updated
// before later steps can fail, so a stage may be partially updated when an
// error is returned.
//
// Transforms are taken from comps without being cloned, so the payload rewrite
// and the side input remapping modify the messages held by comps.
//
// Any panic raised during construction is recovered and returned as err.
func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *engine.ElementManager) (err error) {
	// Catch construction time panics and produce them as errors out.
	defer func() {
		if r := recover(); r != nil {
			switch rt := r.(type) {
			case error:
				err = rt
			default:
				err = fmt.Errorf("%v", r)
			}
		}
	}()
	// Assume stage has an indicated primary input

	// The component maps that end up in the bundle descriptor.
	coders := map[string]*pipepb.Coder{}
	transforms := map[string]*pipepb.PTransform{}
	pcollections := map[string]*pipepb.PCollection{}

	// clonePColToBundle copies the PCollection out of the pipeline components
	// into the bundle's set, so edits to the copy don't affect comps.
	clonePColToBundle := func(pid string) *pipepb.PCollection {
		col := proto.Clone(comps.GetPcollections()[pid]).(*pipepb.PCollection)
		pcollections[pid] = col
		return col
	}

	// Update coders for Stateful transforms.
	for _, tid := range stg.transforms {
		t := comps.GetTransforms()[tid]

		transforms[tid] = t

		if t.GetSpec().GetUrn() != urns.TransformParDo {
			continue
		}

		pardo := &pipepb.ParDoPayload{}
		if err := (proto.UnmarshalOptions{}).Unmarshal(t.GetSpec().GetPayload(), pardo); err != nil {
			return fmt.Errorf("unable to decode ParDoPayload for %v in stage %v: %w", tid, stg.ID, err)
		}

		// We need to ensure the coders can be handled by prism, and are available in the bundle descriptor.
		// So we rewrite the transform's Payload with updated coder ids here.
		var rewrite bool
		var rewriteErr error
		for stateID, s := range pardo.GetStateSpecs() {
			rewrite = true
			// rewriteCoder replaces the coder ID behind cid with the one returned
			// by lpUnknownCoders, recording the first failure in rewriteErr.
			rewriteCoder := func(cid *string) {
				if rewriteErr != nil {
					// Keep the first failure; it is returned once the spec is handled.
					return
				}
				newCid, err := lpUnknownCoders(*cid, coders, comps.GetCoders())
				if err != nil {
					rewriteErr = fmt.Errorf("unable to rewrite coder %v for state %v for transform %v in stage %v: %w", *cid, stateID, tid, stg.ID, err)
					return
				}
				*cid = newCid
			}
			switch s := s.GetSpec().(type) {
			case *pipepb.StateSpec_BagSpec:
				rewriteCoder(&s.BagSpec.ElementCoderId)
			case *pipepb.StateSpec_SetSpec:
				rewriteCoder(&s.SetSpec.ElementCoderId)
			case *pipepb.StateSpec_OrderedListSpec:
				rewriteCoder(&s.OrderedListSpec.ElementCoderId)
				if rewriteErr != nil {
					// The coder lookup below uses the rewritten ID; bail out now so
					// the rewrite failure isn't replaced by an "unknown coder" error.
					return rewriteErr
				}
				// Add the length determination helper for OrderedList state values.
				if stg.stateTypeLen == nil {
					stg.stateTypeLen = map[engine.LinkID]func([]byte) int{}
				}
				linkID := engine.LinkID{
					Transform: tid,
					Local:     stateID,
				}
				// fn reports how many leading bytes of its argument make up one
				// encoded element of the state's element coder.
				// NOTE(review): protowire.ConsumeVarint reports malformed input with
				// a negative n, which is used unchecked below — confirm callers
				// only pass well formed data.
				var fn func([]byte) int
				switch v := coders[s.OrderedListSpec.GetElementCoderId()]; v.GetSpec().GetUrn() {
				case urns.CoderBool:
					fn = func(_ []byte) int {
						return 1
					}
				case urns.CoderDouble:
					fn = func(_ []byte) int {
						return 8
					}
				case urns.CoderVarInt:
					fn = func(b []byte) int {
						_, n := protowire.ConsumeVarint(b)
						return n
					}
				case urns.CoderLengthPrefix, urns.CoderBytes, urns.CoderStringUTF8:
					// A varint length prefix followed by that many bytes.
					fn = func(b []byte) int {
						l, n := protowire.ConsumeVarint(b)
						return int(l) + n
					}
				default:
					// Return before storing, so no nil helper is registered for this state.
					return fmt.Errorf("unknown coder used for ordered list state after re-write id: %v coder: %v, for state %v for transform %v in stage %v", s.OrderedListSpec.GetElementCoderId(), v, stateID, tid, stg.ID)
				}
				stg.stateTypeLen[linkID] = fn
			case *pipepb.StateSpec_CombiningSpec:
				rewriteCoder(&s.CombiningSpec.AccumulatorCoderId)
			case *pipepb.StateSpec_MapSpec:
				rewriteCoder(&s.MapSpec.KeyCoderId)
				rewriteCoder(&s.MapSpec.ValueCoderId)
			case *pipepb.StateSpec_MultimapSpec:
				rewriteCoder(&s.MultimapSpec.KeyCoderId)
				rewriteCoder(&s.MultimapSpec.ValueCoderId)
			case *pipepb.StateSpec_ReadModifyWriteSpec:
				rewriteCoder(&s.ReadModifyWriteSpec.CoderId)
			}
			if rewriteErr != nil {
				return rewriteErr
			}
		}
		for timerID, v := range pardo.GetTimerFamilySpecs() {
			stg.hasTimers = append(stg.hasTimers, engine.StaticTimerID{TransformID: tid, TimerFamily: timerID})
			if v.TimeDomain == pipepb.TimeDomain_PROCESSING_TIME {
				if stg.processingTimeTimers == nil {
					stg.processingTimeTimers = map[string]bool{}
				}
				stg.processingTimeTimers[timerID] = true
			}
			rewrite = true
			newCid, err := lpUnknownCoders(v.GetTimerFamilyCoderId(), coders, comps.GetCoders())
			if err != nil {
				return fmt.Errorf("unable to rewrite coder %v for timer %v for transform %v in stage %v: %w", v.GetTimerFamilyCoderId(), timerID, tid, stg.ID, err)
			}
			v.TimerFamilyCoderId = newCid
		}
		if rewrite {
			pyld, err := proto.MarshalOptions{}.Marshal(pardo)
			if err != nil {
				return fmt.Errorf("unable to encode ParDoPayload for %v in stage %v after rewrite: %w", tid, stg.ID, err)
			}
			// t is the message held by comps, so this updates the pipeline's copy too.
			t.Spec.Payload = pyld
		}
	}
	if len(transforms) == 0 {
		return fmt.Errorf("buildDescriptor: invalid stage - no transforms at all %v", stg.ID)
	}

	// Start with outputs, since they're simple and uniform.
	sink2Col := map[string]string{}
	col2Coders := map[string]engine.PColInfo{}
	for _, o := range stg.outputs {
		col := clonePColToBundle(o.Global)
		wOutCid, err := makeWindowedValueCoder(o.Global, comps, coders)
		if err != nil {
			return fmt.Errorf("buildDescriptor: failed to handle coder on stage %v for output %+v, pcol %q %v:\n%w %v", stg.ID, o, o.Global, prototext.Format(col), err, stg.transforms)
		}
		sinkID := o.Transform + "_" + o.Local
		ed := collectionPullDecoder(col.GetCoderId(), coders, comps)

		// The key decoder stays nil unless extractKVCoderID finds a key coder.
		var kd func(io.Reader) []byte
		if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok {
			kd = collectionPullDecoder(kcid, coders, comps)
		}

		winCoder, wDec, wEnc := getWindowValueCoders(comps, col, coders)
		sink2Col[sinkID] = o.Global
		col2Coders[o.Global] = engine.PColInfo{
			GlobalID:    o.Global,
			WindowCoder: winCoder,
			WDec:        wDec,
			WEnc:        wEnc,
			EDec:        ed,
			KeyDec:      kd,
		}
		transforms[sinkID] = sinkTransform(sinkID, portFor(wOutCid, wk), o.Global)
	}

	var prepareSides []func(b *worker.B, watermark mtime.Time)
	for _, si := range stg.sideInputs {
		col := clonePColToBundle(si.Global)

		oCID := col.GetCoderId()
		nCID, err := lpUnknownCoders(oCID, coders, comps.GetCoders())
		if err != nil {
			return fmt.Errorf("buildDescriptor: failed to handle coder on stage %v for side input %+v, pcol %q %v:\n%w", stg.ID, si, si.Global, prototext.Format(col), err)
		}
		if oCID != nCID {
			// Add a synthetic PCollection set with the new coder.
			newGlobal := si.Global + "_prismside"
			pcollections[newGlobal] = &pipepb.PCollection{
				DisplayData:         col.GetDisplayData(),
				UniqueName:          col.GetUniqueName(),
				CoderId:             nCID,
				IsBounded:           col.GetIsBounded(),
				WindowingStrategyId: col.WindowingStrategyId,
			}
			// Update side inputs to point to new PCollection with any replaced coders.
			transforms[si.Transform].GetInputs()[si.Local] = newGlobal
			// TODO: replace si.Global with newGlobal?
		}
		prepSide, err := handleSideInput(si, comps, transforms, pcollections, coders, em)
		if err != nil {
			slog.Error("buildDescriptor: handleSideInputs", "error", err, slog.String("transformID", si.Transform))
			return err
		}
		prepareSides = append(prepareSides, prepSide)
	}

	// Finally, the parallel input, which is its own special snowflake, that needs a datasource.
	// This id is directly used for the source, but this also copies
	// coders used by side inputs to the coders map for the bundle, so
	// needs to be run for every ID.

	col := clonePColToBundle(stg.primaryInput)
	newCID, err := lpUnknownCoders(col.GetCoderId(), coders, comps.GetCoders())
	if err != nil {
		return fmt.Errorf("buildDescriptor: couldn't rewrite coder %q for primary input pcollection %q: %w", col.GetCoderId(), stg.primaryInput, err)
	}
	if col.GetCoderId() != newCID {
		col.CoderId = newCID
	}

	wInCid, err := makeWindowedValueCoder(stg.primaryInput, comps, coders)
	if err != nil {
		return fmt.Errorf("buildDescriptor: failed to handle coder on stage %v for primary input, pcol %q %v:\n%w\n%v", stg.ID, stg.primaryInput, prototext.Format(col), err, stg.transforms)
	}
	ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
	winCoder, wDec, wEnc := getWindowValueCoders(comps, col, coders)

	// The key decoder stays nil unless extractKVCoderID finds a key coder.
	var kd func(io.Reader) []byte
	if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok {
		kd = collectionPullDecoder(kcid, coders, comps)
	}

	inputInfo := engine.PColInfo{
		GlobalID:    stg.primaryInput,
		WindowCoder: winCoder,
		WDec:        wDec,
		WEnc:        wEnc,
		EDec:        ed,
		KeyDec:      kd,
	}

	stg.inputTransformID = stg.ID + "_source"
	transforms[stg.inputTransformID] = sourceTransform(stg.inputTransformID, portFor(wInCid, wk), stg.primaryInput)

	// Update coders for internal collections, and add those collections to the bundle descriptor.
	for _, pid := range stg.internalCols {
		col := clonePColToBundle(pid)
		// Keep the original coder of an internal pcollection without rewriting(LP'ing).
		if err := retrieveCoders(col.GetCoderId(), coders, comps.GetCoders()); err != nil {
			return fmt.Errorf("buildDescriptor: couldn't retrieve coder %q for internal pcollection %q: %w", col.GetCoderId(), pid, err)
		}
	}
	// Add coders for all windowing strategies.
	// TODO: filter PCollections, filter windowing strategies by Pcollections instead.
	for wsID, ws := range comps.GetWindowingStrategies() {
		// Every strategy in the pipeline is visited, not only the ones this stage
		// uses, so a failure is reported rather than failing the stage.
		if _, err := lpUnknownCoders(ws.GetWindowCoderId(), coders, comps.GetCoders()); err != nil {
			slog.Warn("buildDescriptor: unable to add window coder to bundle", "error", err, slog.String("windowingStrategyID", wsID), slog.String("stageID", stg.ID))
		}
	}

	reconcileCoders(coders, comps.GetCoders())

	// The timer service is only advertised when the stage has timers.
	var timerServiceDescriptor *pipepb.ApiServiceDescriptor
	if len(stg.hasTimers) > 0 {
		timerServiceDescriptor = &pipepb.ApiServiceDescriptor{
			Url: wk.Endpoint(),
		}
	}

	desc := &fnpb.ProcessBundleDescriptor{
		Id:                  stg.ID,
		Transforms:          transforms,
		WindowingStrategies: comps.GetWindowingStrategies(),
		Pcollections:        pcollections,
		Coders:              coders,
		StateApiServiceDescriptor: &pipepb.ApiServiceDescriptor{
			Url: wk.Endpoint(),
		},
		TimerApiServiceDescriptor: timerServiceDescriptor,
	}

	stg.desc = desc
	// Run every side input preparation step collected above, in order.
	stg.prepareSides = func(b *worker.B, watermark mtime.Time) {
		for _, prep := range prepareSides {
			prep(b, watermark)
		}
	}
	stg.SinkToPCollection = sink2Col
	stg.OutputsToCoders = col2Coders
	stg.inputInfo = inputInfo

	wk.Descriptors[stg.ID] = stg.desc
	return nil
}