statefun-sdk-go/v3/pkg/statefun/handler.go (180 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package statefun

import (
	"bytes"
	"context"
	"fmt"
	"log"
	"net/http"
	"sync"

	"github.com/apache/flink-statefun/statefun-sdk-go/v3/pkg/statefun/internal/protocol"
	"google.golang.org/protobuf/proto"
)

// StatefulFunctions is a registry for multiple StatefulFunction instances.
// A RequestReplyHandler can be created from the registry that understands
// how to dispatch invocation requests to the registered functions as well
// as encode side-effects (e.g., sending messages to other functions or
// updating values in storage) as the response.
type StatefulFunctions interface {

	// WithSpec registers a StatefulFunctionSpec, which will be
	// used to build the runtime function. It returns an error
	// if the specification is invalid and the handler
	// fails to register the function.
	WithSpec(spec StatefulFunctionSpec) error

	// AsHandler creates a RequestReplyHandler from the registered
	// function specs.
	AsHandler() RequestReplyHandler
}

// The RequestReplyHandler processes messages
// from the runtime, invokes functions, and encodes
// side effects. The handler implements http.Handler
// so it can easily be embedded in standard Go server
// frameworks.
type RequestReplyHandler interface { http.Handler // Invoke method provides compliance with AWS Lambda handler Invoke(ctx context.Context, payload []byte) ([]byte, error) } // StatefulFunctionsBuilder creates a new StatefulFunctions registry. func StatefulFunctionsBuilder() StatefulFunctions { return &handler{ module: map[TypeName]StatefulFunction{}, stateSpecs: map[TypeName]map[string]*protocol.FromFunction_PersistedValueSpec{}, } } type handler struct { module map[TypeName]StatefulFunction stateSpecs map[TypeName]map[string]*protocol.FromFunction_PersistedValueSpec } func (h *handler) WithSpec(spec StatefulFunctionSpec) error { log.Printf("registering Stateful Function %v\n", spec.FunctionType) if _, exists := h.module[spec.FunctionType]; exists { err := fmt.Errorf("failed to register Stateful Function %s, there is already a spec registered under that type", spec.FunctionType) log.Println(err.Error()) return err } if spec.Function == nil { err := fmt.Errorf("failed to register Stateful Function %s, the Function instance cannot be nil", spec.FunctionType) log.Println(err.Error()) return err } valueSpecs := make(map[string]*protocol.FromFunction_PersistedValueSpec, len(spec.States)) for _, state := range spec.States { log.Printf("registering state specification %v\n", state) if err := validateValueSpec(state); err != nil { err := fmt.Errorf("failed to register Stateful Function %s: %w", spec.FunctionType, err) log.Println(err.Error()) return err } expiration := &protocol.FromFunction_ExpirationSpec{} switch state.Expiration.expirationType { case none: expiration.Mode = protocol.FromFunction_ExpirationSpec_NONE case expireAfterWrite: expiration.Mode = protocol.FromFunction_ExpirationSpec_AFTER_WRITE expiration.ExpireAfterMillis = state.Expiration.duration.Milliseconds() case expireAfterCall: expiration.Mode = protocol.FromFunction_ExpirationSpec_AFTER_INVOKE expiration.ExpireAfterMillis = state.Expiration.duration.Milliseconds() } valueSpecs[state.Name] = 
&protocol.FromFunction_PersistedValueSpec{ StateName: state.Name, ExpirationSpec: expiration, TypeTypename: state.ValueType.GetTypeName().String(), } } h.module[spec.FunctionType] = spec.Function h.stateSpecs[spec.FunctionType] = valueSpecs return nil } func (h *handler) AsHandler() RequestReplyHandler { return h } func (h *handler) ServeHTTP(writer http.ResponseWriter, request *http.Request) { if request.Method != "POST" { http.Error(writer, "invalid request method", http.StatusMethodNotAllowed) return } contentType := request.Header.Get("Content-type") if contentType != "" && contentType != "application/octet-stream" { http.Error(writer, "invalid content type", http.StatusUnsupportedMediaType) return } if request.Body == nil || request.ContentLength == 0 { http.Error(writer, "empty request body", http.StatusBadRequest) return } buffer := bytes.Buffer{} if _, err := buffer.ReadFrom(request.Body); err != nil { http.Error(writer, err.Error(), http.StatusBadRequest) return } response, err := h.Invoke(request.Context(), buffer.Bytes()) if err != nil { log.Println(err.Error()) http.Error(writer, err.Error(), http.StatusInternalServerError) return } _, _ = writer.Write(response) } func (h *handler) Invoke(ctx context.Context, payload []byte) ([]byte, error) { toFunction := protocol.ToFunction{} if err := proto.Unmarshal(payload, &toFunction); err != nil { return nil, fmt.Errorf("failed to unmarshal ToFunction: %w", err) } fromFunction, err := h.invoke(ctx, &toFunction) if err != nil { return nil, err } return proto.Marshal(fromFunction) } func (h *handler) invoke(ctx context.Context, toFunction *protocol.ToFunction) (from *protocol.FromFunction, err error) { batch := toFunction.GetInvocation() self := addressFromInternal(batch.Target) function, exists := h.module[self.FunctionType] defer func() { if r := recover(); r != nil { switch r := r.(type) { case error: err = fmt.Errorf("failed to execute invocation for %s: %w", batch.Target, r) default: log.Fatal(r) } } }() if 
!exists { return nil, fmt.Errorf("unknown function type %s", self.FunctionType) } storageFactory := newStorageFactory(batch, h.stateSpecs[self.FunctionType]) if missing := storageFactory.getMissingSpecs(); missing != nil { log.Printf("missing state specs for function type %v", self) for _, spec := range missing { log.Printf("registering missing specs %v", spec) } return &protocol.FromFunction{ Response: &protocol.FromFunction_IncompleteInvocationContext_{ IncompleteInvocationContext: &protocol.FromFunction_IncompleteInvocationContext{ MissingValues: missing, }, }, }, nil } storage := storageFactory.getStorage() response := &protocol.FromFunction_InvocationResponse{} for _, invocation := range batch.Invocations { select { case <-ctx.Done(): return nil, ctx.Err() default: sContext := statefunContext{ Mutex: new(sync.Mutex), self: self, storage: storage, response: response, } var cancel context.CancelFunc sContext.Context, cancel = context.WithCancel(ctx) if invocation.Caller != nil { caller := addressFromInternal(invocation.Caller) sContext.caller = &caller } msg := Message{ target: batch.Target, typedValue: invocation.Argument, } err = function.Invoke(&sContext, msg) cancel() if err != nil { return } } } response.StateMutations = storage.getStateMutations() from = &protocol.FromFunction{ Response: &protocol.FromFunction_InvocationResult{ InvocationResult: response, }, } return }