middleware/peeked_stream.go (48 lines of code) (raw):

package middleware import ( "context" "fmt" "google.golang.org/grpc" "google.golang.org/protobuf/proto" ) // NewPeekedStream returns a new grpc.ServerStream which allows for // peeking the first message of ServerStream. Reading the first message // would leave handler unable to read the first message as it was // already consumed. PeekedStream allows for restoring the stream so // the RPC handler can read the first message as usual. func NewPeekedStream( ctx context.Context, firstMessage proto.Message, firstError error, ss grpc.ServerStream, ) grpc.ServerStream { return &peekedStream{ context: ctx, firstMessage: firstMessage, firstError: firstError, ServerStream: ss, } } type peekedStream struct { context context.Context firstMessage proto.Message firstError error grpc.ServerStream } // Context provides the context of the peekedStream. func (ps *peekedStream) Context() context.Context { return ps.context } // RecvMsg wraps the RecvMsg from grpc.ServerStream to read the first message. func (ps *peekedStream) RecvMsg(dst any) error { if ps.firstError != nil { firstError := ps.firstError ps.firstError = nil return firstError } if ps.firstMessage != nil { marshaled, err := proto.Marshal(ps.firstMessage) if err != nil { return fmt.Errorf("marshal: %w", err) } if err := proto.Unmarshal(marshaled, dst.(proto.Message)); err != nil { return fmt.Errorf("unmarshal: %w", err) } ps.firstMessage = nil return nil } return ps.ServerStream.RecvMsg(dst) }