arrow/flight/server.go (240 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package flight import ( "context" "fmt" "net" "os" "os/signal" "github.com/apache/arrow-go/v18/arrow/flight/gen/flight" "google.golang.org/grpc" ) type ( FlightServer = flight.FlightServiceServer FlightService_HandshakeServer = flight.FlightService_HandshakeServer HandshakeResponse = flight.HandshakeResponse HandshakeRequest = flight.HandshakeRequest FlightService_ListFlightsServer = flight.FlightService_ListFlightsServer FlightService_DoGetServer = flight.FlightService_DoGetServer FlightService_DoPutServer = flight.FlightService_DoPutServer FlightService_DoExchangeServer = flight.FlightService_DoExchangeServer FlightService_DoActionServer = flight.FlightService_DoActionServer FlightService_ListActionsServer = flight.FlightService_ListActionsServer Criteria = flight.Criteria FlightDescriptor = flight.FlightDescriptor FlightEndpoint = flight.FlightEndpoint Location = flight.Location FlightInfo = flight.FlightInfo PollInfo = flight.PollInfo FlightData = flight.FlightData PutResult = flight.PutResult Ticket = flight.Ticket SchemaResult = flight.SchemaResult Action = flight.Action ActionType = flight.ActionType CancelFlightInfoRequest = flight.CancelFlightInfoRequest RenewFlightEndpointRequest = flight.RenewFlightEndpointRequest Result = flight.Result CancelFlightInfoResult = flight.CancelFlightInfoResult CancelStatus = flight.CancelStatus SessionOptionValue = flight.SessionOptionValue SetSessionOptionsRequest = flight.SetSessionOptionsRequest SetSessionOptionsResult = flight.SetSessionOptionsResult SetSessionOptionsResultError = flight.SetSessionOptionsResult_Error GetSessionOptionsRequest = flight.GetSessionOptionsRequest GetSessionOptionsResult = flight.GetSessionOptionsResult CloseSessionRequest = flight.CloseSessionRequest CloseSessionResult = flight.CloseSessionResult Empty = flight.Empty ) // Constants for Action types const ( CancelFlightInfoActionType = "CancelFlightInfo" RenewFlightEndpointActionType = "RenewFlightEndpoint" SetSessionOptionsActionType = "SetSessionOptions" GetSessionOptionsActionType = "GetSessionOptions" CloseSessionActionType = "CloseSession" ) const ( // The set option error is unknown. Servers should avoid // using this value (send a NOT_FOUND error if the requested // FlightInfo is not known). Clients can retry the request. SetSessionOptionsResultErrorUnspecified = flight.SetSessionOptionsResult_UNSPECIFIED // The given session option name is invalid. SetSessionOptionsResultErrorInvalidName = flight.SetSessionOptionsResult_INVALID_NAME // The session option value or type is invalid. SetSessionOptionsResultErrorInvalidValue = flight.SetSessionOptionsResult_INVALID_VALUE // The session option cannot be set. SetSessionOptionsResultErrorError = flight.SetSessionOptionsResult_ERROR ) const ( // The close session status is unknown. Servers should avoid // using this value (send a NOT_FOUND error if the requested // FlightInfo is not known). Clients can retry the request. CloseSessionResultUnspecified = flight.CloseSessionResult_UNSPECIFIED // The session close request is complete. CloseSessionResultClosed = flight.CloseSessionResult_CLOSED // The session close request is in progress. The client may retry the request. CloseSessionResultClosing = flight.CloseSessionResult_CLOSING // The session is not closeable. CloseSessionResultNotCloseable = flight.CloseSessionResult_NOT_CLOSEABLE ) // NewSessionOptionValues returns a map with the same keys as the input map, but with all values converted // to SessionOptionValues. If any values fail conversion, an error will be returned. func NewSessionOptionValues(options map[string]any) (map[string]*flight.SessionOptionValue, error) { sessionOptions := make(map[string]*flight.SessionOptionValue, len(options)) for key, val := range options { optval, err := NewSessionOptionValue(val) if err != nil { return nil, err } sessionOptions[key] = &optval } return sessionOptions, nil } // NewSessionOptionValue takes any value and constructs a SessionOptionValue suitable for setting session values. // An error will be returned if the value is not one of the types supported by SessionOptionValue. func NewSessionOptionValue(value any) (flight.SessionOptionValue, error) { if value == nil { return flight.SessionOptionValue{}, nil } switch val := value.(type) { case string: return flight.SessionOptionValue{OptionValue: &flight.SessionOptionValue_StringValue{StringValue: val}}, nil case bool: return flight.SessionOptionValue{OptionValue: &flight.SessionOptionValue_BoolValue{BoolValue: val}}, nil case int64: return flight.SessionOptionValue{OptionValue: &flight.SessionOptionValue_Int64Value{Int64Value: val}}, nil case float64: return flight.SessionOptionValue{OptionValue: &flight.SessionOptionValue_DoubleValue{DoubleValue: val}}, nil case []string: return flight.SessionOptionValue{OptionValue: &flight.SessionOptionValue_StringListValue_{StringListValue: &flight.SessionOptionValue_StringListValue{Values: val}}}, nil default: return flight.SessionOptionValue{}, fmt.Errorf("invalid option type %[1]T for value %[1]v", val) } } // Constants for CancelStatus const ( // The cancellation status is unknown. Servers should avoid // using this value (send a NOT_FOUND error if the requested // FlightInfo is not known). Clients can retry the request. CancelStatusUnspecified = flight.CancelStatus_CANCEL_STATUS_UNSPECIFIED // The cancellation request is complete. Subsequent requests // with the same payload may return CancelStatusCancelled or a // arrow.ErrNotFound error. CancelStatusCancelled = flight.CancelStatus_CANCEL_STATUS_CANCELLED // The cancellation request is in progress. The client may // retry the cancellation request. CancelStatusCancelling = flight.CancelStatus_CANCEL_STATUS_CANCELLING // The FlightInfo is not cancellable. The client should not // retry the cancellation request. CancelStatusNotCancellable = flight.CancelStatus_CANCEL_STATUS_NOT_CANCELLABLE ) // Constants for Location const ( // LocationReuseConnection is a special location that tells clients // they may fetch the data from the same service that they obtained // the FlightEndpoint response from. LocationReuseConnection = "arrow-flight-reuse-connection://?" ) // RegisterFlightServiceServer registers an existing flight server onto an // existing grpc server, or anything that is a grpc service registrar. func RegisterFlightServiceServer(s *grpc.Server, srv FlightServer) { flight.RegisterFlightServiceServer(s, srv) } // From https://github.com/grpc/grpc-go/blob/4c776ec01572d55249df309251900554b46adb41/reflection/serverreflection.go#L69-L83 // This interface is inlined to make this arrow library compatible with // grpc < 1.45 . // See "google.golang.org/grpc/reflection" 's reflection.ServiceInfoProvider . // serviceInfoProvider is an interface used to retrieve metadata about the // services to expose. // // The reflection service is only interested in the service names, but the // signature is this way so that *grpc.Server implements it. So it is okay // for a custom implementation to return zero values for the // grpc.ServiceInfo values in the map. // // # Experimental // // Notice: This type is EXPERIMENTAL and may be changed or removed in a // later release. type serviceInfoProvider interface { GetServiceInfo() map[string]grpc.ServiceInfo } // Server is an interface for hiding some of the grpc specifics to make // it slightly easier to manage a flight service, slightly modeled after // the C++ implementation type Server interface { // Init takes in the address to bind to and creates the listener. If both this // and InitListener are called, then whichever was called last will be used. Init(addr string) error // InitListener initializes with an already created listener rather than // creating a new one like Init does. If both this and Init are called, // whichever was called last is what will be used as they both set a listener // into the server. InitListener(lis net.Listener) // Addr will return the address that was bound to for the service to listen on Addr() net.Addr // SetShutdownOnSignals sets notifications on the given signals to call GracefulStop // on the grpc service if any of those signals are received SetShutdownOnSignals(sig ...os.Signal) // Serve blocks until accepting a connection fails with a fatal error. It will return // a non-nil error unless it stopped due to calling Shutdown or receiving one of the // signals set in SetShutdownOnSignals Serve() error // Shutdown will call GracefulStop on the grpc server so that it stops accepting connections // and will wait until current methods complete Shutdown() // RegisterFlightService sets up the handler for the Flight Endpoints as per // normal Grpc setups RegisterFlightService(FlightServer) // ServiceRegistrar wraps a single method that supports service registration. // For example, it may be used to register health check provided by grpc-go. grpc.ServiceRegistrar // serviceInfoProvider is an interface used to retrieve metadata about the services to expose. // If reflection is enabled on the server, all the endpoints can be invoked using grpcurl. serviceInfoProvider } // BaseFlightServer is the base flight server implementation and must be // embedded in any server implementation to ensure forward compatibility // with any modifications of the spec without compiler errors. type BaseFlightServer struct { flight.UnimplementedFlightServiceServer authHandler ServerAuthHandler } func (s *BaseFlightServer) GetAuthHandler() ServerAuthHandler { return s.authHandler } func (s *BaseFlightServer) SetAuthHandler(handler ServerAuthHandler) { s.authHandler = handler } func (s *BaseFlightServer) Handshake(stream flight.FlightService_HandshakeServer) error { if s.authHandler == nil { return nil } return s.authHandler.Authenticate(&serverAuthConn{stream}) } // CustomerServerMiddleware is a helper interface for more easily defining custom // grpc middleware without having to expose or understand all the grpc bells and whistles. type CustomServerMiddleware interface { // StartCall will be called with the current context of the call, grpc.SetHeader can be used to add outgoing headers // if the returned context is non-nil, then it will be used as the new context being passed through the calls StartCall(ctx context.Context) context.Context // CallCompleted is a callback which is called with the return from the handler // it will be nil if everything was successful or will be the error about to be returned // to grpc CallCompleted(ctx context.Context, err error) } // CreateServerMiddlware constructs a ServerMiddleware object for the passed in custom // middleware, generating both the Unary and Stream interceptors from the interface. func CreateServerMiddleware(middleware CustomServerMiddleware) ServerMiddleware { return ServerMiddleware{ Unary: func(ctx context.Context, req interface{}, _ *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (ret interface{}, err error) { nctx := middleware.StartCall(ctx) if nctx != nil { ctx = nctx } ret, err = handler(ctx, req) middleware.CallCompleted(ctx, err) return }, Stream: func(srv interface{}, stream grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error { ctx := middleware.StartCall(stream.Context()) if ctx != nil { stream = &wrappedStream{ServerStream: stream, ctx: ctx} } err := handler(srv, stream) middleware.CallCompleted(stream.Context(), err) return err }, } } type ServerMiddleware struct { Stream grpc.StreamServerInterceptor Unary grpc.UnaryServerInterceptor } type server struct { lis net.Listener sigChannel <-chan os.Signal done chan bool server *grpc.Server } // NewFlightServer takes any grpc Server options desired, such as TLS certs and so // on which will just be passed through to the underlying grpc server. // // Alternatively, a grpc server can be created normally without this helper as the // grpc server generated code is still being exported. This only exists to allow // the utility of the helpers // // Deprecated: prefer to use NewServerWithMiddleware, due to auth handler middleware // this function will be problematic if any of the grpc options specify other middleware. func NewFlightServer(opt ...grpc.ServerOption) Server { opt = append([]grpc.ServerOption{ grpc.ChainStreamInterceptor(serverAuthStreamInterceptor), grpc.ChainUnaryInterceptor(serverAuthUnaryInterceptor), }, opt...) return &server{ server: grpc.NewServer(opt...), } } // NewServerWithMiddleware takes a slice of middleware which will be used // by grpc and chained, the first middleware will be the outer most with the last // middleware being the inner most wrapper around the actual call. It also takes // any grpc Server options desired, such as TLS certs and so on which will just // be passed through to the underlying grpc server. // // Because of the usage of `ChainStreamInterceptor` and `ChainUnaryInterceptor` do // not specify any middleware using the grpc options, use the ServerMiddleware slice // instead as the auth middleware will be added for handling the case that a service // handler is registered that uses the ServerAuthHandler. // // Alternatively, a grpc server can be created normally without this helper as the // grpc server generated code is still being exported. This only exists to allow // the utility of the helpers. func NewServerWithMiddleware(middleware []ServerMiddleware, opts ...grpc.ServerOption) Server { unary := make([]grpc.UnaryServerInterceptor, 1, len(middleware)+1) unary[0] = serverAuthUnaryInterceptor stream := make([]grpc.StreamServerInterceptor, 1, len(middleware)+1) stream[0] = serverAuthStreamInterceptor if len(middleware) > 0 { for _, m := range middleware { if m.Unary != nil { unary = append(unary, m.Unary) } if m.Stream != nil { stream = append(stream, m.Stream) } } } opts = append(opts, grpc.ChainUnaryInterceptor(unary...), grpc.ChainStreamInterceptor(stream...)) return &server{server: grpc.NewServer(opts...)} } func (s *server) Init(addr string) (err error) { s.lis, err = net.Listen("tcp", addr) return } func (s *server) InitListener(lis net.Listener) { s.lis = lis } func (s *server) Addr() net.Addr { return s.lis.Addr() } func (s *server) SetShutdownOnSignals(sig ...os.Signal) { c := make(chan os.Signal, 1) signal.Notify(c, sig...) s.sigChannel = c } func (s *server) Serve() error { s.done = make(chan bool) go func() { select { case <-s.sigChannel: s.server.GracefulStop() case <-s.done: } }() err := s.server.Serve(s.lis) close(s.done) return err } func (s *server) RegisterFlightService(svc FlightServer) { flight.RegisterFlightServiceServer(s.server, svc) } func (s *server) Shutdown() { s.server.GracefulStop() } func (s *server) RegisterService(sd *grpc.ServiceDesc, ss interface{}) { s.server.RegisterService(sd, ss) } func (s *server) GetServiceInfo() map[string]grpc.ServiceInfo { return s.server.GetServiceInfo() }