func()

in sdks/go/pkg/beam/runners/prism/internal/worker/worker.go [406:590]


// State serves the Fn API state stream for this worker.
//
// A goroutine reads StateRequests from the stream and resolves each one against
// the active bundle (*B) registered under the request's instruction ID. It
// queues a matching StateResponse for the request. The calling goroutine writes
// queued responses back on the stream until the reader closes the queue. That
// happens when Recv returns io.EOF or an error with status code Canceled.
//
// Side input window keys are decoded as interval windows; an empty window key
// means the global window. User state (bag, multimap, multimap keys, ordered
// list) is read from and written to the bundle's OutputData.
//
// Requests for instruction IDs that are not an active *B are logged and dropped
// without a response. Unsupported request or key kinds, undecodable window keys,
// and Recv errors other than EOF/Canceled cause a panic. Send errors are logged
// and otherwise ignored. The method always returns nil.
func (wk *W) State(state fnpb.BeamFnState_StateServer) error {
	responses := make(chan *fnpb.StateResponse)
	go func() {
		// This goroutine creates all responses to state requests from the worker,
		// so closing the channel when it returns ends the Send loop below.
		defer close(responses)
		for {
			req, err := state.Recv()
			if err == io.EOF {
				return
			}
			if err != nil {
				switch status.Code(err) {
				case codes.Canceled:
					return
				default:
					slog.Error("state.Recv failed", slog.Any("error", err), slog.Any("worker", wk))
					panic(err)
				}
			}

			// State requests are always for an active ProcessBundle instruction.
			wk.mu.Lock()
			b, ok := wk.activeInstructions[req.GetInstructionId()].(*B)
			wk.mu.Unlock()
			if !ok {
				slog.Warn("state request after bundle inactive", "instruction", req.GetInstructionId(), "worker", wk)
				continue
			}
			switch req.GetRequest().(type) {
			case *fnpb.StateRequest_Get:
				// TODO: move data handling to be pcollection based.

				key := req.GetStateKey()
				// This path runs for every state read, so only build the debug-only
				// values (prototext formatting, window listings) when debug logging
				// is actually enabled.
				debug := slog.Default().Enabled(state.Context(), slog.LevelDebug)
				if debug {
					slog.Debug("StateRequest_Get", "request", prototext.Format(req), "bundle", b)
				}
				var data [][]byte
				switch key.GetType().(type) {
				case *fnpb.StateKey_IterableSideInput_:
					ikey := key.GetIterableSideInput()
					wKey := ikey.GetWindow()
					var w typex.Window
					if len(wKey) == 0 {
						w = window.GlobalWindow{}
					} else {
						w, err = exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey))
						if err != nil {
							panic(fmt.Sprintf("error decoding iterable side input window key %v: %v", wKey, err))
						}
					}
					winMap := b.IterableSideInputData[SideInputKey{TransformID: ikey.GetTransformId(), Local: ikey.GetSideInputId()}]

					if debug {
						var wins []typex.Window
						for win := range winMap {
							wins = append(wins, win)
						}
						slog.Debug(fmt.Sprintf("side input[%v][%v] I Key: %v Windows: %v", req.GetId(), req.GetInstructionId(), w, wins))
					}

					data = winMap[w]

				case *fnpb.StateKey_MultimapKeysSideInput_:
					mmkey := key.GetMultimapKeysSideInput()
					wKey := mmkey.GetWindow()
					var w typex.Window = window.GlobalWindow{}
					if len(wKey) > 0 {
						w, err = exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey))
						if err != nil {
							panic(fmt.Sprintf("error decoding multimap side input window key %v: %v", wKey, err))
						}
					}
					winMap := b.MultiMapSideInputData[SideInputKey{TransformID: mmkey.GetTransformId(), Local: mmkey.GetSideInputId()}]
					// Keys come from ranging over a Go map, so their order is unspecified.
					for k := range winMap[w] {
						data = append(data, []byte(k))
					}

				case *fnpb.StateKey_MultimapSideInput_:
					mmkey := key.GetMultimapSideInput()
					wKey := mmkey.GetWindow()
					var w typex.Window
					if len(wKey) == 0 {
						w = window.GlobalWindow{}
					} else {
						w, err = exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey))
						if err != nil {
							panic(fmt.Sprintf("error decoding multimap side input window key %v: %v", wKey, err))
						}
					}
					dKey := mmkey.GetKey()
					winMap := b.MultiMapSideInputData[SideInputKey{TransformID: mmkey.GetTransformId(), Local: mmkey.GetSideInputId()}]

					if debug {
						slog.Debug(fmt.Sprintf("side input[%v][%v] MultiMap Window: %v", req.GetId(), req.GetInstructionId(), w))
					}

					data = winMap[w][string(dKey)]

				case *fnpb.StateKey_BagUserState_:
					bagkey := key.GetBagUserState()
					data = b.OutputData.GetBagState(engine.LinkID{Transform: bagkey.GetTransformId(), Local: bagkey.GetUserStateId()}, bagkey.GetWindow(), bagkey.GetKey())
				case *fnpb.StateKey_MultimapUserState_:
					mmkey := key.GetMultimapUserState()
					data = b.OutputData.GetMultimapState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey(), mmkey.GetMapKey())
				case *fnpb.StateKey_MultimapKeysUserState_:
					mmkey := key.GetMultimapKeysUserState()
					data = b.OutputData.GetMultimapKeysState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey())
				case *fnpb.StateKey_OrderedListUserState_:
					olkey := key.GetOrderedListUserState()
					data = b.OutputData.GetOrderedListState(
						engine.LinkID{Transform: olkey.GetTransformId(), Local: olkey.GetUserStateId()},
						olkey.GetWindow(), olkey.GetKey(), olkey.GetRange().GetStart(), olkey.GetRange().GetEnd())
				default:
					panic(fmt.Sprintf("unsupported StateKey Get type: %T: %v", key.GetType(), prototext.Format(key)))
				}

				// Encode the runner iterable (no length, just consecutive elements), and send it out.
				// This is also where we can handle things like State Backed Iterables.
				responses <- &fnpb.StateResponse{
					Id: req.GetId(),
					Response: &fnpb.StateResponse_Get{
						Get: &fnpb.StateGetResponse{
							Data: bytes.Join(data, []byte{}),
						},
					},
				}

			case *fnpb.StateRequest_Append:
				key := req.GetStateKey()
				switch key.GetType().(type) {
				case *fnpb.StateKey_BagUserState_:
					bagkey := key.GetBagUserState()
					b.OutputData.AppendBagState(engine.LinkID{Transform: bagkey.GetTransformId(), Local: bagkey.GetUserStateId()}, bagkey.GetWindow(), bagkey.GetKey(), req.GetAppend().GetData())
				case *fnpb.StateKey_MultimapUserState_:
					mmkey := key.GetMultimapUserState()
					b.OutputData.AppendMultimapState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey(), mmkey.GetMapKey(), req.GetAppend().GetData())
				case *fnpb.StateKey_OrderedListUserState_:
					olkey := key.GetOrderedListUserState()
					b.OutputData.AppendOrderedListState(
						engine.LinkID{Transform: olkey.GetTransformId(), Local: olkey.GetUserStateId()},
						olkey.GetWindow(), olkey.GetKey(), req.GetAppend().GetData())
				default:
					panic(fmt.Sprintf("unsupported StateKey Append type: %T: %v", key.GetType(), prototext.Format(key)))
				}

				responses <- &fnpb.StateResponse{
					Id: req.GetId(),
					Response: &fnpb.StateResponse_Append{
						Append: &fnpb.StateAppendResponse{},
					},
				}

			case *fnpb.StateRequest_Clear:
				key := req.GetStateKey()
				switch key.GetType().(type) {
				case *fnpb.StateKey_BagUserState_:
					bagkey := key.GetBagUserState()
					b.OutputData.ClearBagState(engine.LinkID{Transform: bagkey.GetTransformId(), Local: bagkey.GetUserStateId()}, bagkey.GetWindow(), bagkey.GetKey())
				case *fnpb.StateKey_MultimapUserState_:
					mmkey := key.GetMultimapUserState()
					b.OutputData.ClearMultimapState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey(), mmkey.GetMapKey())
				case *fnpb.StateKey_MultimapKeysUserState_:
					// Use the getter matching this oneof case. GetMultimapUserState()
					// returns nil here, which would clear state under empty IDs instead
					// of the requested user state.
					mmkey := key.GetMultimapKeysUserState()
					b.OutputData.ClearMultimapKeysState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey())
				case *fnpb.StateKey_OrderedListUserState_:
					olkey := key.GetOrderedListUserState()
					b.OutputData.ClearOrderedListState(engine.LinkID{Transform: olkey.GetTransformId(), Local: olkey.GetUserStateId()},
						olkey.GetWindow(), olkey.GetKey(), olkey.GetRange().GetStart(), olkey.GetRange().GetEnd())
				default:
					panic(fmt.Sprintf("unsupported StateKey Clear type: %T: %v", key.GetType(), prototext.Format(key)))
				}
				responses <- &fnpb.StateResponse{
					Id: req.GetId(),
					Response: &fnpb.StateResponse_Clear{
						Clear: &fnpb.StateClearResponse{},
					},
				}

			default:
				panic(fmt.Sprintf("unsupported StateRequest kind %T: %v", req.GetRequest(), prototext.Format(req)))
			}
		}
	}()
	// Drain every response, even after a Send failure, so the reader goroutine
	// never blocks on the unbuffered channel.
	for resp := range responses {
		if err := state.Send(resp); err != nil {
			slog.Error("state.Send", slog.Any("error", err))
		}
	}
	return nil
}