func buildDescriptor()

in sdks/go/pkg/beam/runners/prism/internal/stage.go [386:664]


// buildDescriptor assembles the ProcessBundleDescriptor for stg and records the
// per-stage metadata needed to run bundles for it.
//
// It copies the stage's transforms, and the PCollections and coders they
// reference, out of comps into bundle-local maps. Along the way it rewrites
// coder ids with lpUnknownCoders for:
//   - ParDo state and timer specs
//   - side inputs
//   - the primary input
//
// Internal PCollections keep their original coders (retrieveCoders).
//
// It adds a synthetic sink transform for every stage output and a synthetic
// source transform for the primary input. Their ports come from portFor(…, wk).
//
// On success it sets stg.desc, stg.prepareSides, stg.SinkToPCollection,
// stg.OutputsToCoders, stg.inputInfo and stg.inputTransformID, and registers
// the descriptor in wk.Descriptors under stg.ID. It also populates
// stg.stateTypeLen, stg.hasTimers and stg.processingTimeTimers as it encounters
// ordered-list state and timers.
//
// Panics raised while building the descriptor are recovered and returned as err.
func buildDescriptor(stg *stage, comps *pipepb.Components, wk *worker.W, em *engine.ElementManager) (err error) {
	// Catch construction time panics and produce them as errors out.
	defer func() {
		if r := recover(); r != nil {
			switch rt := r.(type) {
			case error:
				err = rt
			default:
				err = fmt.Errorf("%v", r)
			}
		}
	}()
	// Assume stage has an indicated primary input

	// Bundle-local views of the components this stage's descriptor references.
	coders := map[string]*pipepb.Coder{}
	transforms := map[string]*pipepb.PTransform{}
	pcollections := map[string]*pipepb.PCollection{}

	// clonePColToBundle deep-copies PCollection pid from comps into the bundle's
	// pcollections map. Later coder-id edits on the returned copy therefore do
	// not touch comps.
	clonePColToBundle := func(pid string) *pipepb.PCollection {
		col := proto.Clone(comps.GetPcollections()[pid]).(*pipepb.PCollection)
		pcollections[pid] = col
		return col
	}

	// Update coders for Stateful transforms.
	for _, tid := range stg.transforms {
		t := comps.GetTransforms()[tid]

		transforms[tid] = t

		if t.GetSpec().GetUrn() != urns.TransformParDo {
			continue
		}

		pardo := &pipepb.ParDoPayload{}
		if err := (proto.UnmarshalOptions{}).Unmarshal(t.GetSpec().GetPayload(), pardo); err != nil {
			return fmt.Errorf("unable to decode ParDoPayload for %v in stage %v: %w", tid, stg.ID, err)
		}

		// We need to ensure the coders can be handled by prism, and are available in the bundle descriptor.
		// So we rewrite the transform's Payload with updated coder ids here.
		var rewrite bool
		var rewriteErr error
		for stateID, s := range pardo.GetStateSpecs() {
			rewrite = true
			// rewriteCoder replaces *cid with its prism-compatible id. On failure it
			// records rewriteErr and leaves *cid unchanged.
			rewriteCoder := func(cid *string) {
				newCid, err := lpUnknownCoders(*cid, coders, comps.GetCoders())
				if err != nil {
					rewriteErr = fmt.Errorf("unable to rewrite coder %v for state %v for transform %v in stage %v:%w", *cid, stateID, tid, stg.ID, err)
					return
				}
				*cid = newCid
			}
			switch s := s.GetSpec().(type) {
			case *pipepb.StateSpec_BagSpec:
				rewriteCoder(&s.BagSpec.ElementCoderId)
			case *pipepb.StateSpec_SetSpec:
				rewriteCoder(&s.SetSpec.ElementCoderId)
			case *pipepb.StateSpec_OrderedListSpec:
				rewriteCoder(&s.OrderedListSpec.ElementCoderId)
				if rewriteErr != nil {
					// The element coder id was not rewritten. Looking it up below would
					// replace this error with a misleading "unknown coder" one, so
					// leave the switch and report the rewrite failure instead.
					break
				}
				// Add the length determination helper for OrderedList state values.
				linkID := engine.LinkID{
					Transform: tid,
					Local:     stateID,
				}
				// fn reports how many leading bytes of an encoded element belong to
				// that element, based on the (rewritten) element coder's URN.
				var fn func([]byte) int
				switch v := coders[s.OrderedListSpec.GetElementCoderId()]; v.GetSpec().GetUrn() {
				case urns.CoderBool:
					fn = func(_ []byte) int {
						return 1
					}
				case urns.CoderDouble:
					fn = func(_ []byte) int {
						return 8
					}
				case urns.CoderVarInt:
					fn = func(b []byte) int {
						_, n := protowire.ConsumeVarint(b)
						return int(n)
					}
				case urns.CoderLengthPrefix, urns.CoderBytes, urns.CoderStringUTF8:
					// Varint length header followed by that many payload bytes.
					fn = func(b []byte) int {
						l, n := protowire.ConsumeVarint(b)
						return int(l) + n
					}
				default:
					rewriteErr = fmt.Errorf("unknown coder used for ordered list state after re-write id: %v coder: %v, for state %v for transform %v in stage %v", s.OrderedListSpec.GetElementCoderId(), v, stateID, tid, stg.ID)
				}
				// Only register a usable helper; never store a nil func for the link.
				if fn != nil {
					if stg.stateTypeLen == nil {
						stg.stateTypeLen = map[engine.LinkID]func([]byte) int{}
					}
					stg.stateTypeLen[linkID] = fn
				}
			case *pipepb.StateSpec_CombiningSpec:
				rewriteCoder(&s.CombiningSpec.AccumulatorCoderId)
			case *pipepb.StateSpec_MapSpec:
				rewriteCoder(&s.MapSpec.KeyCoderId)
				rewriteCoder(&s.MapSpec.ValueCoderId)
			case *pipepb.StateSpec_MultimapSpec:
				rewriteCoder(&s.MultimapSpec.KeyCoderId)
				rewriteCoder(&s.MultimapSpec.ValueCoderId)
			case *pipepb.StateSpec_ReadModifyWriteSpec:
				rewriteCoder(&s.ReadModifyWriteSpec.CoderId)
			}
			if rewriteErr != nil {
				return rewriteErr
			}
		}
		for timerID, v := range pardo.GetTimerFamilySpecs() {
			stg.hasTimers = append(stg.hasTimers, engine.StaticTimerID{TransformID: tid, TimerFamily: timerID})
			if v.TimeDomain == pipepb.TimeDomain_PROCESSING_TIME {
				if stg.processingTimeTimers == nil {
					stg.processingTimeTimers = map[string]bool{}
				}
				stg.processingTimeTimers[timerID] = true
			}
			rewrite = true
			newCid, err := lpUnknownCoders(v.GetTimerFamilyCoderId(), coders, comps.GetCoders())
			if err != nil {
				return fmt.Errorf("unable to rewrite coder %v for timer %v for transform %v in stage %v: %w", v.GetTimerFamilyCoderId(), timerID, tid, stg.ID, err)
			}
			v.TimerFamilyCoderId = newCid
		}
		if rewrite {
			pyld, err := proto.MarshalOptions{}.Marshal(pardo)
			if err != nil {
				return fmt.Errorf("unable to encode ParDoPayload for %v in stage %v after rewrite: %w", tid, stg.ID, err)
			}
			// NOTE(review): t is taken directly from comps and is not cloned, so this
			// also replaces the payload on the shared pipeline components. Confirm
			// that is intended.
			t.Spec.Payload = pyld
		}
	}
	if len(transforms) == 0 {
		return fmt.Errorf("buildDescriptor: invalid stage - no transforms at all %v", stg.ID)
	}

	// Start with outputs, since they're simple and uniform.
	sink2Col := map[string]string{}
	col2Coders := map[string]engine.PColInfo{}
	for _, o := range stg.outputs {
		col := clonePColToBundle(o.Global)
		wOutCid, err := makeWindowedValueCoder(o.Global, comps, coders)
		if err != nil {
			return fmt.Errorf("buildDescriptor: failed to handle coder on stage %v for output %+v, pcol %q %v:\n%w %v", stg.ID, o, o.Global, prototext.Format(col), err, stg.transforms)
		}
		sinkID := o.Transform + "_" + o.Local
		ed := collectionPullDecoder(col.GetCoderId(), coders, comps)

		// Key decoder is only available when the output coder is a KV coder.
		var kd func(io.Reader) []byte
		if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok {
			kd = collectionPullDecoder(kcid, coders, comps)
		}

		winCoder, wDec, wEnc := getWindowValueCoders(comps, col, coders)
		sink2Col[sinkID] = o.Global
		col2Coders[o.Global] = engine.PColInfo{
			GlobalID:    o.Global,
			WindowCoder: winCoder,
			WDec:        wDec,
			WEnc:        wEnc,
			EDec:        ed,
			KeyDec:      kd,
		}
		transforms[sinkID] = sinkTransform(sinkID, portFor(wOutCid, wk), o.Global)
	}

	var prepareSides []func(b *worker.B, watermark mtime.Time)
	for _, si := range stg.sideInputs {
		col := clonePColToBundle(si.Global)

		oCID := col.GetCoderId()
		nCID, err := lpUnknownCoders(oCID, coders, comps.GetCoders())
		if err != nil {
			return fmt.Errorf("buildDescriptor: failed to handle coder on stage %v for side input %+v, pcol %q %v:\n%w", stg.ID, si, si.Global, prototext.Format(col), err)
		}
		if oCID != nCID {
			// Add a synthetic PCollection set with the new coder.
			newGlobal := si.Global + "_prismside"
			pcollections[newGlobal] = &pipepb.PCollection{
				DisplayData:         col.GetDisplayData(),
				UniqueName:          col.GetUniqueName(),
				CoderId:             nCID,
				IsBounded:           col.GetIsBounded(),
				WindowingStrategyId: col.WindowingStrategyId,
			}
			// Update side inputs to point to new PCollection with any replaced coders.
			// NOTE(review): transforms[si.Transform] is the un-cloned transform from
			// comps, so this edits the shared pipeline components too.
			transforms[si.Transform].GetInputs()[si.Local] = newGlobal
			// TODO: replace si.Global with newGlobal?
		}
		prepSide, err := handleSideInput(si, comps, transforms, pcollections, coders, em)
		if err != nil {
			slog.Error("buildDescriptor: handleSideInputs", "error", err, slog.String("transformID", si.Transform))
			return err
		}
		prepareSides = append(prepareSides, prepSide)
	}

	// Finally, the parallel input, which is its own special snowflake, that needs a datasource.
	// This id is directly used for the source, but this also copies
	// coders used by side inputs to the coders map for the bundle, so
	// needs to be run for every ID.

	col := clonePColToBundle(stg.primaryInput)
	newCID, err := lpUnknownCoders(col.GetCoderId(), coders, comps.GetCoders())
	if err != nil {
		return fmt.Errorf("buildDescriptor: couldn't rewrite coder %q for primary input pcollection %q: %w", col.GetCoderId(), stg.primaryInput, err)
	}
	// col is a bundle-local clone, so this does not modify comps.
	col.CoderId = newCID

	wInCid, err := makeWindowedValueCoder(stg.primaryInput, comps, coders)
	if err != nil {
		return fmt.Errorf("buildDescriptor: failed to handle coder on stage %v for primary input, pcol %q %v:\n%w\n%v", stg.ID, stg.primaryInput, prototext.Format(col), err, stg.transforms)
	}
	ed := collectionPullDecoder(col.GetCoderId(), coders, comps)
	winCoder, wDec, wEnc := getWindowValueCoders(comps, col, coders)

	var kd func(io.Reader) []byte
	if kcid, ok := extractKVCoderID(col.GetCoderId(), coders); ok {
		kd = collectionPullDecoder(kcid, coders, comps)
	}

	inputInfo := engine.PColInfo{
		GlobalID:    stg.primaryInput,
		WindowCoder: winCoder,
		WDec:        wDec,
		WEnc:        wEnc,
		EDec:        ed,
		KeyDec:      kd,
	}

	stg.inputTransformID = stg.ID + "_source"
	transforms[stg.inputTransformID] = sourceTransform(stg.inputTransformID, portFor(wInCid, wk), stg.primaryInput)

	// Update coders for internal collections, and add those collections to the bundle descriptor.
	for _, pid := range stg.internalCols {
		col := clonePColToBundle(pid)
		// Keep the original coder of an internal pcollection without rewriting(LP'ing).
		if err := retrieveCoders(col.GetCoderId(), coders, comps.GetCoders()); err != nil {
			return fmt.Errorf("buildDescriptor: couldn't retrieve coder %q for internal pcollection %q: %w", col.GetCoderId(), pid, err)
		}
	}
	// Add coders for all windowing strategies.
	// TODO: filter PCollections, filter windowing strategies by Pcollections instead.
	for _, ws := range comps.GetWindowingStrategies() {
		// Best-effort: every windowing strategy in the pipeline is visited, not
		// just those used by this stage, so failures are deliberately ignored.
		// NOTE(review): confirm ignoring errors here is intended.
		_, _ = lpUnknownCoders(ws.GetWindowCoderId(), coders, comps.GetCoders())
	}

	reconcileCoders(coders, comps.GetCoders())

	// The timer API endpoint is only advertised when the stage has timers.
	var timerServiceDescriptor *pipepb.ApiServiceDescriptor
	if len(stg.hasTimers) > 0 {
		timerServiceDescriptor = &pipepb.ApiServiceDescriptor{
			Url: wk.Endpoint(),
		}
	}

	desc := &fnpb.ProcessBundleDescriptor{
		Id:                  stg.ID,
		Transforms:          transforms,
		WindowingStrategies: comps.GetWindowingStrategies(),
		Pcollections:        pcollections,
		Coders:              coders,
		StateApiServiceDescriptor: &pipepb.ApiServiceDescriptor{
			Url: wk.Endpoint(),
		},
		TimerApiServiceDescriptor: timerServiceDescriptor,
	}

	stg.desc = desc
	stg.prepareSides = func(b *worker.B, watermark mtime.Time) {
		for _, prep := range prepareSides {
			prep(b, watermark)
		}
	}
	stg.SinkToPCollection = sink2Col
	stg.OutputsToCoders = col2Coders
	stg.inputInfo = inputInfo

	wk.Descriptors[stg.ID] = stg.desc
	return nil
}