in sdks/go/pkg/beam/runners/prism/internal/worker/worker.go [406:590]
// State serves the state stream opened by an SDK worker.
//
// A goroutine receives requests from the stream, resolves each one against
// the bundle registered in wk.activeInstructions under the request's
// instruction ID, and produces exactly one response per handled Get, Append
// or Clear request. The calling goroutine forwards those responses to the
// stream, and returns nil once the receiving goroutine has finished.
//
// The receiving goroutine stops on io.EOF or a Canceled status from Recv.
// Any other Recv error, an undecodable side input window, or an unsupported
// request or state key type causes a panic. Requests for an instruction that
// is not an active bundle are logged and dropped without a response.
func (wk *W) State(state fnpb.BeamFnState_StateServer) error {
	responses := make(chan *fnpb.StateResponse)
	go func() {
		// This goroutine creates all responses to state requests from the worker
		// so we want to close the State handler when it's all done.
		defer close(responses)

		// decodeWindow turns the encoded window of a side input state key into
		// a window value. An empty encoding means the global window. kind only
		// labels the panic message.
		decodeWindow := func(kind string, wKey []byte) typex.Window {
			if len(wKey) == 0 {
				return window.GlobalWindow{}
			}
			w, err := exec.MakeWindowDecoder(coder.NewIntervalWindow()).DecodeSingle(bytes.NewBuffer(wKey))
			if err != nil {
				panic(fmt.Sprintf("error decoding %v side input window key %v: %v", kind, wKey, err))
			}
			return w
		}

		for {
			req, err := state.Recv()
			if err == io.EOF {
				return
			}
			if err != nil {
				switch status.Code(err) {
				case codes.Canceled:
					return
				default:
					slog.Error("state.Recv failed", slog.Any("error", err), slog.Any("worker", wk))
					panic(err)
				}
			}

			// State requests are always for an active ProcessBundle instruction
			wk.mu.Lock()
			b, ok := wk.activeInstructions[req.GetInstructionId()].(*B)
			wk.mu.Unlock()
			if !ok {
				slog.Warn("state request after bundle inactive", "instruction", req.GetInstructionId(), "worker", wk)
				continue
			}

			switch req.GetRequest().(type) {
			case *fnpb.StateRequest_Get:
				// TODO: move data handling to be pcollection based.
				key := req.GetStateKey()
				slog.Debug("StateRequest_Get", "request", prototext.Format(req), "bundle", b)
				var data [][]byte
				switch key.GetType().(type) {
				case *fnpb.StateKey_IterableSideInput_:
					ikey := key.GetIterableSideInput()
					w := decodeWindow("iterable", ikey.GetWindow())
					winMap := b.IterableSideInputData[SideInputKey{TransformID: ikey.GetTransformId(), Local: ikey.GetSideInputId()}]
					// The available windows are collected only for the debug log below.
					var wins []typex.Window
					for win := range winMap {
						wins = append(wins, win)
					}
					slog.Debug(fmt.Sprintf("side input[%v][%v] I Key: %v Windows: %v", req.GetId(), req.GetInstructionId(), w, wins))
					data = winMap[w]
				case *fnpb.StateKey_MultimapKeysSideInput_:
					mmkey := key.GetMultimapKeysSideInput()
					w := decodeWindow("multimap", mmkey.GetWindow())
					winMap := b.MultiMapSideInputData[SideInputKey{TransformID: mmkey.GetTransformId(), Local: mmkey.GetSideInputId()}]
					// Only the keys present in the window are returned.
					for k := range winMap[w] {
						data = append(data, []byte(k))
					}
				case *fnpb.StateKey_MultimapSideInput_:
					mmkey := key.GetMultimapSideInput()
					w := decodeWindow("multimap", mmkey.GetWindow())
					dKey := mmkey.GetKey()
					winMap := b.MultiMapSideInputData[SideInputKey{TransformID: mmkey.GetTransformId(), Local: mmkey.GetSideInputId()}]
					slog.Debug(fmt.Sprintf("side input[%v][%v] MultiMap Window: %v", req.GetId(), req.GetInstructionId(), w))
					data = winMap[w][string(dKey)]
				case *fnpb.StateKey_BagUserState_:
					bagkey := key.GetBagUserState()
					data = b.OutputData.GetBagState(engine.LinkID{Transform: bagkey.GetTransformId(), Local: bagkey.GetUserStateId()}, bagkey.GetWindow(), bagkey.GetKey())
				case *fnpb.StateKey_MultimapUserState_:
					mmkey := key.GetMultimapUserState()
					data = b.OutputData.GetMultimapState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey(), mmkey.GetMapKey())
				case *fnpb.StateKey_MultimapKeysUserState_:
					mmkey := key.GetMultimapKeysUserState()
					data = b.OutputData.GetMultimapKeysState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey())
				case *fnpb.StateKey_OrderedListUserState_:
					olkey := key.GetOrderedListUserState()
					data = b.OutputData.GetOrderedListState(
						engine.LinkID{Transform: olkey.GetTransformId(), Local: olkey.GetUserStateId()},
						olkey.GetWindow(), olkey.GetKey(), olkey.GetRange().GetStart(), olkey.GetRange().GetEnd())
				default:
					panic(fmt.Sprintf("unsupported StateKey Get type: %T: %v", key.GetType(), prototext.Format(key)))
				}
				// Encode the runner iterable (no length, just consecutive elements), and send it out.
				// This is also where we can handle things like State Backed Iterables.
				responses <- &fnpb.StateResponse{
					Id: req.GetId(),
					Response: &fnpb.StateResponse_Get{
						Get: &fnpb.StateGetResponse{
							Data: bytes.Join(data, []byte{}),
						},
					},
				}

			case *fnpb.StateRequest_Append:
				key := req.GetStateKey()
				switch key.GetType().(type) {
				case *fnpb.StateKey_BagUserState_:
					bagkey := key.GetBagUserState()
					b.OutputData.AppendBagState(engine.LinkID{Transform: bagkey.GetTransformId(), Local: bagkey.GetUserStateId()}, bagkey.GetWindow(), bagkey.GetKey(), req.GetAppend().GetData())
				case *fnpb.StateKey_MultimapUserState_:
					mmkey := key.GetMultimapUserState()
					b.OutputData.AppendMultimapState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey(), mmkey.GetMapKey(), req.GetAppend().GetData())
				case *fnpb.StateKey_OrderedListUserState_:
					olkey := key.GetOrderedListUserState()
					b.OutputData.AppendOrderedListState(
						engine.LinkID{Transform: olkey.GetTransformId(), Local: olkey.GetUserStateId()},
						olkey.GetWindow(), olkey.GetKey(), req.GetAppend().GetData())
				default:
					panic(fmt.Sprintf("unsupported StateKey Append type: %T: %v", key.GetType(), prototext.Format(key)))
				}
				responses <- &fnpb.StateResponse{
					Id: req.GetId(),
					Response: &fnpb.StateResponse_Append{
						Append: &fnpb.StateAppendResponse{},
					},
				}

			case *fnpb.StateRequest_Clear:
				key := req.GetStateKey()
				switch key.GetType().(type) {
				case *fnpb.StateKey_BagUserState_:
					bagkey := key.GetBagUserState()
					b.OutputData.ClearBagState(engine.LinkID{Transform: bagkey.GetTransformId(), Local: bagkey.GetUserStateId()}, bagkey.GetWindow(), bagkey.GetKey())
				case *fnpb.StateKey_MultimapUserState_:
					mmkey := key.GetMultimapUserState()
					b.OutputData.ClearMultimapState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey(), mmkey.GetMapKey())
				case *fnpb.StateKey_MultimapKeysUserState_:
					// Read the key through the getter matching this oneof case. The
					// MultimapUserState getter yields nil here, which would issue the
					// clear with an empty link, window and key instead of the
					// requested ones.
					mmkey := key.GetMultimapKeysUserState()
					b.OutputData.ClearMultimapKeysState(engine.LinkID{Transform: mmkey.GetTransformId(), Local: mmkey.GetUserStateId()}, mmkey.GetWindow(), mmkey.GetKey())
				case *fnpb.StateKey_OrderedListUserState_:
					olkey := key.GetOrderedListUserState()
					b.OutputData.ClearOrderedListState(engine.LinkID{Transform: olkey.GetTransformId(), Local: olkey.GetUserStateId()},
						olkey.GetWindow(), olkey.GetKey(), olkey.GetRange().GetStart(), olkey.GetRange().GetEnd())
				default:
					panic(fmt.Sprintf("unsupported StateKey Clear type: %T: %v", key.GetType(), prototext.Format(key)))
				}
				responses <- &fnpb.StateResponse{
					Id: req.GetId(),
					Response: &fnpb.StateResponse_Clear{
						Clear: &fnpb.StateClearResponse{},
					},
				}

			default:
				panic(fmt.Sprintf("unsupported StateRequest kind %T: %v", req.GetRequest(), prototext.Format(req)))
			}
		}
	}()

	// Forward responses until the receiving goroutine closes the channel.
	// A failed Send is logged and the loop keeps draining, so the producer is
	// never left blocked on the unbuffered channel.
	for resp := range responses {
		if err := state.Send(resp); err != nil {
			slog.Error("state.Send", slog.Any("error", err))
		}
	}
	return nil
}