arrow/flight/client.go (234 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package flight
import (
"context"
"encoding/base64"
"io"
"runtime"
"strings"
"sync/atomic"
"golang.org/x/xerrors"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
"google.golang.org/grpc/status"
)
// Client is an interface wrapped around the generated FlightServiceClient which is
// generated by grpc protobuf definitions. This interface provides a useful hiding
// of the authentication handshake via calling Authenticate and using the
// ClientAuthHandler rather than manually having to implement the grpc communication
// and sending of the auth token.
type Client interface {
// Authenticate uses the ClientAuthHandler that was used when creating the client
// in order to use the Handshake endpoints of the service.
Authenticate(context.Context, ...grpc.CallOption) error
// AuthenticateBasicToken performs a Handshake call that carries the username and
// password in a "Basic" Authorization header. On success the returned context is
// ctx with the token from the response's "authorization" metadata appended to its
// outgoing metadata; on failure the original ctx is returned with the error.
AuthenticateBasicToken(ctx context.Context, username string, password string, opts ...grpc.CallOption) (context.Context, error)
// Close closes the underlying grpc connection. The implementation in this file
// also drops its embedded FlightServiceClient.
Close() error
// join the interface from the FlightServiceClient instead of re-defining all
// the endpoints here.
FlightServiceClient
}
// CustomClientMiddleware is the minimal interface a type must implement to be
// turned into a ClientMiddleware by CreateClientMiddleware. A middleware may
// additionally implement ClientPostCallMiddleware and/or ClientHeadersMiddleware
// to receive further callbacks.
type CustomClientMiddleware interface {
// StartCall is invoked before each unary or streaming call is issued. When the
// returned context is non-nil it replaces ctx for the rest of that call.
StartCall(ctx context.Context) context.Context
}
// ClientPostCallMiddleware is an optional interface for a CustomClientMiddleware
// that wants to be told when a call has finished.
type ClientPostCallMiddleware interface {
// CallCompleted receives the context used for the call and the error the call
// ended with; err is nil when the call finished without error.
CallCompleted(ctx context.Context, err error)
}
// ClientHeadersMiddleware is an optional interface for a CustomClientMiddleware
// that wants to inspect the metadata returned by the server.
type ClientHeadersMiddleware interface {
// HeadersReceived receives the context used for the call and md, which is the
// response headers and trailers joined into a single metadata.MD.
HeadersReceived(ctx context.Context, md metadata.MD)
}
// CreateClientMiddleware adapts a CustomClientMiddleware into a ClientMiddleware
// holding a matching pair of grpc unary and stream client interceptors.
//
// For every call the interceptors invoke middleware.StartCall first and, when it
// returns a non-nil context, use that context for the remainder of the call. If
// middleware also implements ClientHeadersMiddleware it is handed the response
// headers and trailers joined together, and if it implements
// ClientPostCallMiddleware it is handed the error the call finished with.
//
// For streaming calls the completion callbacks fire exactly once, on whichever
// happens first: an error from Header/SendMsg/RecvMsg/CloseSend, io.EOF (or the
// single response of a non-server-streaming call) from RecvMsg, the stream's
// context being done, or the wrapping stream being garbage collected.
func CreateClientMiddleware(middleware CustomClientMiddleware) ClientMiddleware {
	// The optional capabilities depend only on the middleware's dynamic type, so
	// they are resolved once here instead of on every call.
	hdrs, isHdrs := middleware.(ClientHeadersMiddleware)
	post, isPostcall := middleware.(ClientPostCallMiddleware)

	return ClientMiddleware{
		Unary: func(ctx context.Context, method string, req, reply interface{}, cc *grpc.ClientConn, invoker grpc.UnaryInvoker, opts ...grpc.CallOption) error {
			if nctx := middleware.StartCall(ctx); nctx != nil {
				ctx = nctx
			}

			if isHdrs {
				hdrmd := make(metadata.MD)
				trailermd := make(metadata.MD)
				// The full slice expression caps the capacity so that append always
				// copies and never writes into the caller's backing array.
				opts = append(opts[:len(opts):len(opts)], grpc.Header(&hdrmd), grpc.Trailer(&trailermd))
				// Deferred so the metadata is reported after the call (and after
				// CallCompleted below) regardless of how the call ended.
				defer func() {
					hdrs.HeadersReceived(ctx, metadata.Join(hdrmd, trailermd))
				}()
			}

			err := invoker(ctx, method, req, reply, cc, opts...)
			if isPostcall {
				post.CallCompleted(ctx, err)
			}
			return err
		},
		Stream: func(ctx context.Context, desc *grpc.StreamDesc, cc *grpc.ClientConn, method string, streamer grpc.Streamer, opts ...grpc.CallOption) (grpc.ClientStream, error) {
			if nctx := middleware.StartCall(ctx); nctx != nil {
				ctx = nctx
			}

			cs, err := streamer(ctx, desc, cc, method, opts...)
			if !isPostcall && !isHdrs {
				// Nothing to notify, so there is no need to wrap the stream.
				return cs, err
			}

			if err != nil {
				if isHdrs {
					// A failed streamer normally returns a nil stream; calling
					// Header/Trailer on it would panic, so report empty metadata
					// in that case.
					md := metadata.MD{}
					if cs != nil {
						// Best effort: a header error just means fewer entries.
						hdrmd, _ := cs.Header()
						md = metadata.Join(hdrmd, cs.Trailer())
					}
					hdrs.HeadersReceived(ctx, md)
				}
				if isPostcall {
					post.CallCompleted(ctx, err)
				}
				return cs, err
			}

			// Grab the client stream context because when the finish function or the goroutine below will be
			// executed it's not guaranteed cs.Context() will be valid.
			csCtx := cs.Context()
			finishChan := make(chan struct{})
			var isFinished int32

			finishFunc := func(err error) {
				// since there are multiple code paths that could call finishFunc
				// we need some sort of synchronization to guard against multiple
				// calls to finish
				if !atomic.CompareAndSwapInt32(&isFinished, 0, 1) {
					return
				}
				// Releases the context-watching goroutine below.
				close(finishChan)

				if isPostcall {
					post.CallCompleted(csCtx, err)
				}
				if isHdrs {
					// Best effort: a header error just means fewer entries.
					hdrmd, _ := cs.Header()
					hdrs.HeadersReceived(csCtx, metadata.Join(hdrmd, cs.Trailer()))
				}
			}

			// Make sure the middleware is notified if the stream's context ends
			// before any of the wrapped stream methods observes completion.
			go func() {
				select {
				case <-finishChan:
					// finish is being called by something else, no action necessary
				case <-csCtx.Done():
					finishFunc(csCtx.Err())
				}
			}()

			newCS := &clientStream{
				ClientStream: cs,
				desc:         desc,
				finishFn:     finishFunc,
			}
			// The `ClientStream` interface allows one to omit calling `Recv` if it's
			// known that the result will be `io.EOF`. See
			// http://stackoverflow.com/q/42915337
			// In such cases, there's nothing that triggers the completion callbacks.
			// We, therefore, set a finalizer so that the callbacks run and the
			// context goroutine is at least cleaned up when the garbage collector
			// is run.
			runtime.SetFinalizer(newCS, func(newcs *clientStream) {
				newcs.finishFn(nil)
			})
			return newCS, nil
		},
	}
}
// clientStream wraps a grpc.ClientStream so that the stream interceptor built by
// CreateClientMiddleware can detect when the stream has finished.
type clientStream struct {
// ClientStream is the wrapped stream; methods not overridden on clientStream
// are served by it directly.
grpc.ClientStream
// desc describes the stream; RecvMsg consults desc.ServerStreams to decide
// whether a single successful receive completes the call.
desc *grpc.StreamDesc
// finishFn is called with the error that ended the stream, or nil when the
// stream ended without error.
finishFn func(error)
}
// Header returns the header metadata of the wrapped stream. A failure to obtain
// the headers is reported through finishFn before being returned to the caller.
func (cs *clientStream) Header() (metadata.MD, error) {
	hdr, hdrErr := cs.ClientStream.Header()
	if hdrErr == nil {
		return hdr, nil
	}
	cs.finishFn(hdrErr)
	return hdr, hdrErr
}
// SendMsg forwards m to the wrapped stream. A send failure is reported through
// finishFn before being returned to the caller.
func (cs *clientStream) SendMsg(m interface{}) error {
	sendErr := cs.ClientStream.SendMsg(m)
	if sendErr == nil {
		return nil
	}
	cs.finishFn(sendErr)
	return sendErr
}
// RecvMsg receives the next message from the wrapped stream into m and reports
// completion through finishFn when the stream is over: with nil on io.EOF, with
// the error on any other failure, and with nil after the single response of a
// call whose descriptor is not server-streaming. The wrapped stream's result is
// always returned unchanged.
func (cs *clientStream) RecvMsg(m interface{}) error {
	recvErr := cs.ClientStream.RecvMsg(m)
	switch {
	case recvErr == io.EOF:
		// A clean end of stream is not an error for the middleware.
		cs.finishFn(nil)
	case recvErr != nil:
		cs.finishFn(recvErr)
	case !cs.desc.ServerStreams:
		// Only one response is expected, so this receive completes the call.
		cs.finishFn(nil)
	}
	return recvErr
}
// CloseSend closes the send direction of the wrapped stream. A failure is
// reported through finishFn before being returned to the caller.
func (cs *clientStream) CloseSend() error {
	closeErr := cs.ClientStream.CloseSend()
	if closeErr == nil {
		return nil
	}
	cs.finishFn(closeErr)
	return closeErr
}
// ClientMiddleware bundles the grpc client interceptors that make up one
// middleware. NewClientWithMiddleware skips a field that is nil, so a middleware
// may provide only one of the two.
type ClientMiddleware struct {
// Stream intercepts streaming calls.
Stream grpc.StreamClientInterceptor
// Unary intercepts unary calls.
Unary grpc.UnaryClientInterceptor
}
// client is the Client implementation returned by NewFlightClient and
// NewClientWithMiddleware.
type client struct {
// conn is the grpc connection the client was created on; Close closes it.
conn *grpc.ClientConn
// authHandler drives the handshake in Authenticate; it may be nil, in which
// case Authenticate returns an error.
authHandler ClientAuthHandler
// FlightServiceClient is the generated client bound to conn and supplies the
// Flight RPC methods. Close sets it to nil.
FlightServiceClient
}
// NewFlightClient dials the grpc server at addr and returns a Client for it. When
// auth is non-nil, stream and unary interceptors built from it are installed ahead
// of the supplied dial options and the handler is kept for use by Authenticate.
// TLS and any other grpc configuration can be supplied through opts exactly as
// when dialing manually without this helper.
//
// The generated grpc client code remains exported, so a client can also be built
// by hand; this function only adds convenience around the authentication
// handshake and sending the token with requests.
//
// Deprecated: prefer to use NewClientWithMiddleware
func NewFlightClient(addr string, auth ClientAuthHandler, opts ...grpc.DialOption) (Client, error) {
	dialOpts := opts
	if auth != nil {
		// Auth interceptors go first, followed by the caller's options.
		dialOpts = make([]grpc.DialOption, 0, len(opts)+2)
		dialOpts = append(dialOpts,
			grpc.WithChainStreamInterceptor(createClientAuthStreamInterceptor(auth)),
			grpc.WithChainUnaryInterceptor(createClientAuthUnaryInterceptor(auth)),
		)
		dialOpts = append(dialOpts, opts...)
	}

	conn, err := grpc.Dial(addr, dialOpts...)
	if err != nil {
		return nil, err
	}

	return &client{
		conn:                conn,
		authHandler:         auth,
		FlightServiceClient: NewFlightServiceClient(conn),
	}, nil
}
// NewClientWithMiddleware dials the grpc server at addr and returns a Client whose
// calls pass through the given middleware. The interceptors are chained in slice
// order: the first middleware is the outermost and the last is the innermost
// wrapper around the actual call. When auth is non-nil its interceptors are placed
// ahead of all middleware and the handler is kept for use by Authenticate. A nil
// Unary or Stream field in a ClientMiddleware is skipped.
//
// opts are passed to grpc.Dial ahead of the interceptor chain options, so TLS
// credentials and similar settings can be supplied as usual. The caller's opts
// slice is never modified.
func NewClientWithMiddleware(addr string, auth ClientAuthHandler, middleware []ClientMiddleware, opts ...grpc.DialOption) (Client, error) {
	// +1 leaves room for the optional auth interceptor.
	unary := make([]grpc.UnaryClientInterceptor, 0, len(middleware)+1)
	stream := make([]grpc.StreamClientInterceptor, 0, len(middleware)+1)
	if auth != nil {
		unary = append(unary, createClientAuthUnaryInterceptor(auth))
		stream = append(stream, createClientAuthStreamInterceptor(auth))
	}

	for _, m := range middleware {
		if m.Unary != nil {
			unary = append(unary, m.Unary)
		}
		if m.Stream != nil {
			stream = append(stream, m.Stream)
		}
	}

	// Build a fresh slice instead of appending to opts: a caller passing
	// `someSlice...` with spare capacity would otherwise have elements of its
	// backing array overwritten.
	dialOpts := make([]grpc.DialOption, 0, len(opts)+2)
	dialOpts = append(dialOpts, opts...)
	dialOpts = append(dialOpts, grpc.WithChainUnaryInterceptor(unary...), grpc.WithChainStreamInterceptor(stream...))

	conn, err := grpc.Dial(addr, dialOpts...)
	if err != nil {
		return nil, err
	}
	return &client{conn: conn, FlightServiceClient: NewFlightServiceClient(conn), authHandler: auth}, nil
}
// AuthenticateBasicToken performs a Handshake call whose outgoing metadata
// carries "Authorization: Basic <base64(username:password)>" (encoded without
// padding), then looks for a non-empty "authorization" value in the response
// headers and trailers.
//
// On success it returns ctx with that value appended to the outgoing metadata
// under "Authorization", ready to be used for subsequent calls. On any failure,
// including a response without an authorization value, it returns the original
// ctx together with the error.
func (c *client) AuthenticateBasicToken(ctx context.Context, username, password string, opts ...grpc.CallOption) (context.Context, error) {
	userPass := strings.Join([]string{username, password}, ":")
	basicValue := "Basic " + base64.RawStdEncoding.EncodeToString([]byte(userPass))
	handshakeCtx := metadata.AppendToOutgoingContext(ctx, "Authorization", basicValue)

	stream, err := c.FlightServiceClient.Handshake(handshakeCtx, opts...)
	if err != nil {
		return ctx, err
	}

	respHeader, err := stream.Header()
	if err != nil {
		return ctx, err
	}

	// The handshake payload itself is not needed; io.EOF simply means the
	// server sent none.
	if _, err = stream.Recv(); err != nil && err != io.EOF {
		return ctx, err
	}

	if err = stream.CloseSend(); err != nil {
		return ctx, err
	}

	// The token may arrive in either the headers or the trailers.
	combined := metadata.Join(respHeader, stream.Trailer())
	for _, token := range combined.Get("authorization") {
		if token == "" {
			continue
		}
		return metadata.AppendToOutgoingContext(ctx, "Authorization", token), nil
	}

	return ctx, xerrors.Errorf("flight: no authorization header on the response")
}
// Authenticate opens a Handshake stream and hands it to the ClientAuthHandler the
// client was created with. It returns a NotFound status error when the client has
// no auth handler, the Handshake error if the stream cannot be opened, and
// otherwise whatever the handler's Authenticate returns.
func (c *client) Authenticate(ctx context.Context, opts ...grpc.CallOption) error {
	handler := c.authHandler
	if handler == nil {
		return status.Error(codes.NotFound, "cannot authenticate without an auth-handler")
	}

	handshake, err := c.FlightServiceClient.Handshake(ctx, opts...)
	if err != nil {
		return err
	}

	authConn := &clientAuthConn{handshake}
	return handler.Authenticate(ctx, authConn)
}
// Close drops the embedded FlightServiceClient and closes the underlying grpc
// connection, returning the error from closing the connection.
func (c *client) Close() error {
	conn := c.conn
	c.FlightServiceClient = nil
	return conn.Close()
}