in sdks/go/pkg/beam/core/runtime/graphx/translate.go [364:723]
// addMultiEdge translates a single MultiEdge into its pipeline proto
// PTransform(s), registers the result on the marshaller, and returns the IDs of
// the transforms that represent the edge.
//
// If a transform with the edge's ID is already registered, that ID is returned
// without further work. CoGBKs with more than one input, Reshuffles, and
// External edges without a payload are delegated to expandCoGBK,
// expandReshuffle and expandCrossLanguage respectively. Every other edge has
// its input and output nodes added, a FunctionSpec built according to its
// opcode, and a single PTransform stored in m.transforms.
//
// An error is returned if a node, coder or window fn cannot be translated, if
// an input kind, state type or opcode is not recognized, or if a DoFn with
// timers does not have a KV-coded main input.
func (m *marshaller) addMultiEdge(edge NamedEdge) ([]string, error) {
	// handleErr wraps err with the edge being translated.
	handleErr := func(err error) ([]string, error) {
		return nil, errors.Wrapf(err, "failed to add input kind: %v", edge)
	}

	id := edgeID(edge.Edge)
	if _, exists := m.transforms[id]; exists {
		return []string{id}, nil
	}

	// Edges that expand into composite transforms are handled by dedicated
	// helpers and return early.
	switch {
	case edge.Edge.Op == graph.CoGBK && len(edge.Edge.Input) > 1:
		cogbkID, err := m.expandCoGBK(edge)
		if err != nil {
			return handleErr(err)
		}
		return []string{cogbkID}, nil
	case edge.Edge.Op == graph.Reshuffle:
		reshuffleID, err := m.expandReshuffle(edge)
		if err != nil {
			return handleErr(err)
		}
		return []string{reshuffleID}, nil
	case edge.Edge.Op == graph.External:
		if edge.Edge.External != nil && edge.Edge.External.Expanded != nil {
			m.needsExpansion = true
		}
		// External edges that carry a payload fall through and are translated
		// like any other primitive below.
		if edge.Edge.Payload == nil {
			xlangID, err := m.expandCrossLanguage(edge)
			if err != nil {
				return handleErr(err)
			}
			return []string{xlangID}, nil
		}
	}

	inputs := make(map[string]string)
	for i, in := range edge.Edge.Input {
		if _, err := m.addNode(in.From); err != nil {
			return handleErr(err)
		}
		inputs[fmt.Sprintf("i%v", i)] = nodeID(in.From)
	}
	outputs := make(map[string]string)
	for i, out := range edge.Edge.Output {
		if _, err := m.addNode(out.To); err != nil {
			return handleErr(err)
		}
		outputs[fmt.Sprintf("i%v", i)] = nodeID(out.To)
	}

	var annotations map[string][]byte
	var spec *pipepb.FunctionSpec
	switch edge.Edge.Op {
	case graph.Impulse:
		spec = &pipepb.FunctionSpec{Urn: URNImpulse}

	case graph.ParDo:
		si := make(map[string]*pipepb.SideInput)
		for i, in := range edge.Edge.Input {
			// The side input kinds differ only in their access pattern.
			var accessURN string
			switch in.Kind {
			case graph.Main:
				continue // not a side input
			case graph.Singleton, graph.Slice, graph.Iter, graph.ReIter:
				accessURN = URNIterableSideInput
			case graph.Map, graph.MultiMap:
				// Already in a MultiMap form, don't need to add a fixed key.
				accessURN = URNMultimapSideInput
			default:
				return nil, errors.Errorf("unexpected input kind: %v", edge)
			}

			// Get window mapping, arrange proto field.
			siWfn := in.From.WindowingStrategy().Fn
			mappingUrn := getSideWindowMappingUrn(siWfn)
			siWSpec, err := makeWindowFn(siWfn)
			if err != nil {
				return handleErr(err)
			}
			si[fmt.Sprintf("i%v", i)] = &pipepb.SideInput{
				AccessPattern: &pipepb.FunctionSpec{
					Urn: accessURN,
				},
				// NOTE(review): "foo" looks like a placeholder URN; confirm
				// that consumers of this proto ignore the ViewFn.
				ViewFn: &pipepb.FunctionSpec{
					Urn: "foo",
				},
				WindowMappingFn: &pipepb.FunctionSpec{
					Urn:     mappingUrn,
					Payload: siWSpec.Payload,
				},
			}
		}

		encodedEdge, err := mustEncodeMultiEdgeBase64(edge.Edge)
		if err != nil {
			return handleErr(err)
		}
		payload := &pipepb.ParDoPayload{
			DoFn: &pipepb.FunctionSpec{
				Urn:     URNDoFn,
				Payload: []byte(encodedEdge),
			},
			SideInputs: si,
		}

		if edge.Edge.DoFn.IsSplittable() {
			coderID, err := m.coders.Add(edge.Edge.RestrictionCoder)
			if err != nil {
				return handleErr(err)
			}
			payload.RestrictionCoderId = coderID
			m.requirements[URNRequiresSplittableDoFn] = true
		}

		if _, ok := edge.Edge.DoFn.ProcessElementFn().BundleFinalization(); ok {
			payload.RequestsFinalization = true
			m.requirements[URNRequiresBundleFinalization] = true
		}

		if _, ok := edge.Edge.DoFn.ProcessElementFn().StateProvider(); ok {
			m.requirements[URNRequiresStatefulProcessing] = true
			stateSpecs := make(map[string]*pipepb.StateSpec)
			for _, ps := range edge.Edge.DoFn.PipelineState() {
				// Coder IDs stay empty when no coder was registered for the state.
				coderID := ""
				if c, ok := edge.Edge.StateCoders[UserStateCoderID(ps)]; ok {
					coderID, err = m.coders.Add(c)
					if err != nil {
						return handleErr(err)
					}
				}
				keyCoderID := ""
				if c, ok := edge.Edge.StateCoders[UserStateKeyCoderID(ps)]; ok {
					keyCoderID, err = m.coders.Add(c)
					if err != nil {
						return handleErr(err)
					}
				} else if ps.StateType() == state.TypeMap || ps.StateType() == state.TypeSet {
					return nil, errors.Errorf("set or map state type %v must have a key coder type, none detected", ps)
				}

				switch ps.StateType() {
				case state.TypeValue:
					stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
						Spec: &pipepb.StateSpec_ReadModifyWriteSpec{
							ReadModifyWriteSpec: &pipepb.ReadModifyWriteStateSpec{
								CoderId: coderID,
							},
						},
						Protocol: &pipepb.FunctionSpec{
							Urn: URNBagUserState,
						},
					}
				case state.TypeBag:
					stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
						Spec: &pipepb.StateSpec_BagSpec{
							BagSpec: &pipepb.BagStateSpec{
								ElementCoderId: coderID,
							},
						},
						Protocol: &pipepb.FunctionSpec{
							Urn: URNBagUserState,
						},
					}
				case state.TypeCombining:
					// Checked assertion: report a mislabelled state as an error
					// rather than panicking during translation.
					combining, ok := ps.(state.CombiningPipelineState)
					if !ok {
						return nil, errors.Errorf("combining state %v does not provide a combine fn, got %T", ps.StateKey(), ps)
					}
					f, err := graph.NewFn(combining.GetCombineFn())
					if err != nil {
						return handleErr(err)
					}
					cf, err := graph.AsCombineFn(f)
					if err != nil {
						return handleErr(err)
					}
					// The combine fn is serialized as a synthetic Combine edge.
					me := graph.MultiEdge{
						Op:        graph.Combine,
						CombineFn: cf,
					}
					encodedCombine, err := mustEncodeMultiEdgeBase64(&me)
					if err != nil {
						return handleErr(err)
					}
					stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
						Spec: &pipepb.StateSpec_CombiningSpec{
							CombiningSpec: &pipepb.CombiningStateSpec{
								AccumulatorCoderId: coderID,
								CombineFn: &pipepb.FunctionSpec{
									Urn:     "beam:combinefn:gosdk:v1",
									Payload: []byte(encodedCombine),
								},
							},
						},
						Protocol: &pipepb.FunctionSpec{
							Urn: URNBagUserState,
						},
					}
				case state.TypeMap:
					stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
						Spec: &pipepb.StateSpec_MapSpec{
							MapSpec: &pipepb.MapStateSpec{
								KeyCoderId:   keyCoderID,
								ValueCoderId: coderID,
							},
						},
						Protocol: &pipepb.FunctionSpec{
							Urn: URNMultiMapUserState,
						},
					}
				case state.TypeSet:
					stateSpecs[ps.StateKey()] = &pipepb.StateSpec{
						Spec: &pipepb.StateSpec_SetSpec{
							SetSpec: &pipepb.SetStateSpec{
								ElementCoderId: keyCoderID,
							},
						},
						Protocol: &pipepb.FunctionSpec{
							Urn: URNMultiMapUserState,
						},
					}
				default:
					return nil, errors.Errorf("State type %v not recognized for state %v", ps.StateType(), ps)
				}
			}
			payload.StateSpecs = stateSpecs
		}

		if _, ok := edge.Edge.DoFn.ProcessElementFn().TimerProvider(); ok {
			m.requirements[URNRequiresStatefulProcessing] = true
			timerSpecs := make(map[string]*pipepb.TimerFamilySpec)
			// Only the first result is needed to build the timer family specs.
			pipelineTimers, _ := edge.Edge.DoFn.PipelineTimers()

			// All timers for a single DoFn have the same key and window coders, that match the input PCollection.
			mainInputID := inputs["i0"]
			pCol := m.pcollections[mainInputID]
			// Use the nil-safe getter: if the main input is not registered, the
			// KV check below reports an error instead of dereferencing nil.
			kvCoder := m.coders.coders[pCol.GetCoderId()]
			if kvCoder.GetSpec().GetUrn() != urnKVCoder {
				return nil, errors.Errorf("timer using DoFn %v doesn't use a KV as PCollection input. Unable to extract key coder for timers, got %v", edge.Name, kvCoder.GetSpec().GetUrn())
			}
			keyCoderID := kvCoder.GetComponentCoderIds()[0]

			wsID := pCol.GetWindowingStrategyId()
			ws := m.windowing[wsID]
			windowCoderID := ws.GetWindowCoderId()

			timerCoderID := m.coders.internBuiltInCoder(urnTimerCoder, keyCoderID, windowCoderID)

			for _, pt := range pipelineTimers {
				for timerFamilyID, timeDomain := range pt.Timers() {
					timerSpecs[timerFamilyID] = &pipepb.TimerFamilySpec{
						TimeDomain:         pipepb.TimeDomain_Enum(timeDomain),
						TimerFamilyCoderId: timerCoderID,
					}
				}
			}
			payload.TimerFamilySpecs = timerSpecs
		}

		spec = &pipepb.FunctionSpec{Urn: URNParDo, Payload: protox.MustEncode(payload)}
		annotations = edge.Edge.DoFn.Annotations()

	case graph.Combine:
		// Combines are emitted as a ParDo payload carrying the encoded edge.
		encodedEdge, err := mustEncodeMultiEdgeBase64(edge.Edge)
		if err != nil {
			return handleErr(err)
		}
		payload := &pipepb.ParDoPayload{
			DoFn: &pipepb.FunctionSpec{
				Urn:     URNDoFn,
				Payload: []byte(encodedEdge),
			},
		}
		spec = &pipepb.FunctionSpec{Urn: URNParDo, Payload: protox.MustEncode(payload)}

	case graph.Flatten:
		spec = &pipepb.FunctionSpec{Urn: URNFlatten}

	case graph.CoGBK:
		// Only CoGBKs with at most one input reach this point; the rest were
		// expanded above.
		spec = &pipepb.FunctionSpec{Urn: URNGBK}

	case graph.WindowInto:
		windowFn, err := makeWindowFn(edge.Edge.WindowFn)
		if err != nil {
			return handleErr(err)
		}
		payload := &pipepb.WindowIntoPayload{
			WindowFn: windowFn,
		}
		spec = &pipepb.FunctionSpec{Urn: URNWindow, Payload: protox.MustEncode(payload)}

	case graph.External:
		// Payload is non-nil here: payload-less External edges returned early.
		pyld := edge.Edge.Payload
		spec = &pipepb.FunctionSpec{Urn: pyld.URN, Payload: pyld.Data}

		// When the payload names its inputs/outputs, those tags replace the
		// positional "i<N>" keys built above.
		if len(pyld.InputsMap) != 0 {
			if got, want := len(pyld.InputsMap), len(edge.Edge.Input); got != want {
				return handleErr(errors.Errorf("mismatch'd counts between External tags (%v) and inputs (%v)", got, want))
			}
			inputs = make(map[string]string)
			for tag, in := range InboundTagToNode(pyld.InputsMap, edge.Edge.Input) {
				if _, err := m.addNode(in); err != nil {
					return handleErr(err)
				}
				inputs[tag] = nodeID(in)
			}
		}

		if len(pyld.OutputsMap) != 0 {
			if got, want := len(pyld.OutputsMap), len(edge.Edge.Output); got != want {
				return handleErr(errors.Errorf("mismatch'd counts between External tags (%v) and outputs (%v)", got, want))
			}
			outputs = make(map[string]string)
			for tag, out := range OutboundTagToNode(pyld.OutputsMap, edge.Edge.Output) {
				if _, err := m.addNode(out); err != nil {
					return handleErr(err)
				}
				outputs[tag] = nodeID(out)
			}
		}

	default:
		return handleErr(errors.Errorf("unexpected opcode: %v", edge.Edge.Op))
	}

	// GBK and Impulse transforms are left without an environment; everything
	// else is assigned the default one.
	var transformEnvID string
	if spec.Urn != URNGBK && spec.Urn != URNImpulse {
		transformEnvID = m.addDefaultEnv()
	}

	m.transforms[id] = &pipepb.PTransform{
		UniqueName:    edge.Name,
		Spec:          spec,
		Inputs:        inputs,
		Outputs:       outputs,
		EnvironmentId: transformEnvID,
		Annotations:   annotations,
	}
	return []string{id}, nil
}