arrow/flight/client.go (234 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package flight import ( "context" "encoding/base64" "io" "runtime" "strings" "sync/atomic" "golang.org/x/xerrors" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" ) // Client is an interface wrapped around the generated FlightServiceClient which is // generated by grpc protobuf definitions. This interface provides a useful hiding // of the authentication handshake via calling Authenticate and using the // ClientAuthHandler rather than manually having to implement the grpc communication // and sending of the auth token. type Client interface { // Authenticate uses the ClientAuthHandler that was used when creating the client // in order to use the Handshake endpoints of the service. Authenticate(context.Context, ...grpc.CallOption) error AuthenticateBasicToken(ctx context.Context, username string, password string, opts ...grpc.CallOption) (context.Context, error) Close() error // join the interface from the FlightServiceClient instead of re-defining all // the endpoints here. FlightServiceClient } type CustomClientMiddleware interface { StartCall(ctx context.Context) context.Context } type ClientPostCallMiddleware interface { CallCompleted(ctx context.Context, err error) } type ClientHeadersMiddleware interface { HeadersReceived(ctx context.Context, md metadata.MD) } func CreateClientMiddleware(middleware CustomClientMiddleware) ClientMiddleware { return ClientMiddleware{ Unary: func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error { nctx := middleware.StartCall(ctx) if nctx != nil { ctx = nctx } if hdrs, ok := middleware.(ClientHeadersMiddleware); ok { hdrmd := make(metadata.MD) trailermd := make(metadata.MD) opts = append(opts, grpc.Header(&hdrmd), grpc.Trailer(&trailermd)) defer func() { hdrs.HeadersReceived(ctx, metadata.Join(hdrmd, trailermd)) }() } err := invoker(ctx, method, req, reply, cc, opts...) if post, ok := middleware.(ClientPostCallMiddleware); ok { post.CallCompleted(ctx, err) } return err }, Stream: func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) { nctx := middleware.StartCall(ctx) if nctx != nil { ctx = nctx } cs, err := streamer(ctx, desc, cc, method, opts...) hdrs, isHdrs := middleware.(ClientHeadersMiddleware) post, isPostcall := middleware.(ClientPostCallMiddleware) if !isPostcall && !isHdrs { return cs, err } if err != nil { if isHdrs { md, _ := cs.Header() hdrs.HeadersReceived(ctx, metadata.Join(md, cs.Trailer())) } if isPostcall { post.CallCompleted(ctx, err) } return cs, err } // Grab the client stream context because when the finish function or the goroutine below will be // executed it's not guaranteed cs.Context() will be valid. csCtx := cs.Context() finishChan := make(chan struct{}) isFinished := new(int32) *isFinished = 0 finishFunc := func(err error) { // since there are multiple code paths that could call finishFunc // we need some sort of synchronization to guard against multiple // calls to finish if !atomic.CompareAndSwapInt32(isFinished, 0, 1) { return } close(finishChan) if isPostcall { post.CallCompleted(csCtx, err) } if isHdrs { hdrmd, _ := cs.Header() hdrs.HeadersReceived(csCtx, metadata.Join(hdrmd, cs.Trailer())) } } go func() { select { case <-finishChan: // finish is being called by something else, no action necessary case <-csCtx.Done(): finishFunc(csCtx.Err()) } }() newCS := &clientStream{ ClientStream: cs, desc: desc, finishFn: finishFunc, } // The `ClientStream` interface allows one to omit calling `Recv` if it's // known that the result will be `io.EOF`. See // http://stackoverflow.com/q/42915337 // In such cases, there's nothing that triggers the span to finish. We, // therefore, set a finalizer so that the span and the context goroutine will // at least be cleaned up when the garbage collector is run. runtime.SetFinalizer(newCS, func(newcs *clientStream) { newcs.finishFn(nil) }) return newCS, nil }, } } type clientStream struct { grpc.ClientStream desc *grpc.StreamDesc finishFn func(error) } func (cs *clientStream) Header() (metadata.MD, error) { md, err := cs.ClientStream.Header() if err != nil { cs.finishFn(err) } return md, err } func (cs *clientStream) SendMsg(m interface{}) error { err := cs.ClientStream.SendMsg(m) if err != nil { cs.finishFn(err) } return err } func (cs *clientStream) RecvMsg(m interface{}) error { err := cs.ClientStream.RecvMsg(m) if err == io.EOF { cs.finishFn(nil) return err } else if err != nil { cs.finishFn(err) return err } if !cs.desc.ServerStreams { cs.finishFn(nil) } return err } func (cs *clientStream) CloseSend() error { err := cs.ClientStream.CloseSend() if err != nil { cs.finishFn(err) } return err } type ClientMiddleware struct { Stream grpc.StreamClientInterceptor Unary grpc.UnaryClientInterceptor } type client struct { conn *grpc.ClientConn authHandler ClientAuthHandler FlightServiceClient } // NewFlightClient takes in the address of the grpc server and an auth handler for the // application-level handshake. If using TLS or other grpc configurations they can still // be passed via the grpc.DialOption list just as if connecting manually without this // helper function. // // Alternatively, a grpc client can be constructed as normal without this helper as the // grpc generated client code is still exported. This exists to add utility and helpers // around the authentication and passing the token with requests. // // Deprecated: prefer to use NewClientWithMiddleware func NewFlightClient(addr string, auth ClientAuthHandler, opts ...grpc.DialOption) (Client, error) { if auth != nil { opts = append([]grpc.DialOption{ grpc.WithChainStreamInterceptor(createClientAuthStreamInterceptor(auth)), grpc.WithChainUnaryInterceptor(createClientAuthUnaryInterceptor(auth)), }, opts...) } conn, err := grpc.Dial(addr, opts...) if err != nil { return nil, err } return &client{conn: conn, FlightServiceClient: NewFlightServiceClient(conn), authHandler: auth}, nil } // NewClientWithMiddleware takes a slice of middlewares in addition to the auth and address which will be // used by grpc and chained, the first middleware will be the outer most with the last middleware // being the inner most wrapper around the actual call. It also passes along the dialoptions passed in such // as TLS certs and so on. func NewClientWithMiddleware(addr string, auth ClientAuthHandler, middleware []ClientMiddleware, opts ...grpc.DialOption) (Client, error) { unary := make([]grpc.UnaryClientInterceptor, 0, len(middleware)) stream := make([]grpc.StreamClientInterceptor, 0, len(middleware)) if auth != nil { unary = append(unary, createClientAuthUnaryInterceptor(auth)) stream = append(stream, createClientAuthStreamInterceptor(auth)) } if len(middleware) > 0 { for _, m := range middleware { if m.Unary != nil { unary = append(unary, m.Unary) } if m.Stream != nil { stream = append(stream, m.Stream) } } } opts = append(opts, grpc.WithChainUnaryInterceptor(unary...), grpc.WithChainStreamInterceptor(stream...)) conn, err := grpc.Dial(addr, opts...) if err != nil { return nil, err } return &client{conn: conn, FlightServiceClient: NewFlightServiceClient(conn), authHandler: auth}, nil } func (c *client) AuthenticateBasicToken(ctx context.Context, username, password string, opts ...grpc.CallOption) (context.Context, error) { authCtx := metadata.AppendToOutgoingContext(ctx, "Authorization", "Basic "+base64.RawStdEncoding.EncodeToString([]byte(strings.Join([]string{username, password}, ":")))) stream, err := c.FlightServiceClient.Handshake(authCtx, opts...) if err != nil { return ctx, err } header, err := stream.Header() if err != nil { return ctx, err } _, err = stream.Recv() if err != nil && err != io.EOF { return ctx, err } err = stream.CloseSend() if err != nil { return ctx, err } meta := stream.Trailer() md := metadata.Join(header, meta) for _, token := range md.Get("authorization") { if token != "" { return metadata.AppendToOutgoingContext(ctx, "Authorization", token), nil } } return ctx, xerrors.Errorf("flight: no authorization header on the response") } func (c *client) Authenticate(ctx context.Context, opts ...grpc.CallOption) error { if c.authHandler == nil { return status.Error(codes.NotFound, "cannot authenticate without an auth-handler") } stream, err := c.FlightServiceClient.Handshake(ctx, opts...) if err != nil { return err } return c.authHandler.Authenticate(ctx, &clientAuthConn{stream}) } func (c *client) Close() error { c.FlightServiceClient = nil return c.conn.Close() }