arrow/internal/flight_integration/scenario.go (409 lines of code) (raw):
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package flight_integration
import (
"bytes"
"context"
"fmt"
"io"
"net"
"os"
"strconv"
"github.com/aliyun/aliyun-odps-go-sdk/arrow"
"github.com/aliyun/aliyun-odps-go-sdk/arrow/array"
"github.com/aliyun/aliyun-odps-go-sdk/arrow/flight"
"github.com/aliyun/aliyun-odps-go-sdk/arrow/internal/arrjson"
"github.com/aliyun/aliyun-odps-go-sdk/arrow/internal/testing/types"
"github.com/aliyun/aliyun-odps-go-sdk/arrow/ipc"
"github.com/aliyun/aliyun-odps-go-sdk/arrow/memory"
"golang.org/x/xerrors"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/proto"
)
type Scenario interface {
MakeServer(port int) flight.Server
RunClient(addr string, opts ...grpc.DialOption) error
}
func GetScenario(name string, args ...string) Scenario {
switch name {
case "auth:basic_proto":
return &authBasicProtoTester{}
case "middleware":
return &middlewareScenarioTester{}
case "":
if len(args) > 0 {
return &defaultIntegrationTester{path: args[0]}
}
return &defaultIntegrationTester{}
}
panic(fmt.Errorf("scenario not found: %s", name))
}
func initServer(port int, srv flight.Server) int {
srv.Init(fmt.Sprintf("0.0.0.0:%d", port))
_, p, _ := net.SplitHostPort(srv.Addr().String())
port, _ = strconv.Atoi(p)
return port
}
type integrationDataSet struct {
schema *arrow.Schema
chunks []array.Record
}
func consumeFlightLocation(ctx context.Context, loc *flight.Location, tkt *flight.Ticket, orig []array.Record, opts ...grpc.DialOption) error {
client, err := flight.NewClientWithMiddleware(loc.GetUri(), nil, nil, opts...)
if err != nil {
return err
}
defer client.Close()
stream, err := client.DoGet(ctx, tkt)
if err != nil {
return err
}
rdr, err := flight.NewRecordReader(stream)
if err != nil {
return err
}
defer rdr.Release()
for i, chunk := range orig {
if !rdr.Next() {
return xerrors.Errorf("got fewer batches than expected, received so far: %d, expected: %d", i, len(orig))
}
if !array.RecordEqual(chunk, rdr.Record()) {
return xerrors.Errorf("batch %d doesn't match", i)
}
if string(rdr.LatestAppMetadata()) != strconv.Itoa(i) {
return xerrors.Errorf("expected metadata value: %s, but got: %s", strconv.Itoa(i), string(rdr.LatestAppMetadata()))
}
}
if rdr.Next() {
return xerrors.Errorf("got more batches than the expected: %d", len(orig))
}
return nil
}
type defaultIntegrationTester struct {
port int
path string
uploadedChunks map[string]integrationDataSet
}
func (s *defaultIntegrationTester) RunClient(addr string, opts ...grpc.DialOption) error {
client, err := flight.NewClientWithMiddleware(addr, nil, nil, opts...)
if err != nil {
return err
}
defer client.Close()
ctx := context.Background()
arrow.RegisterExtensionType(types.NewUUIDType())
defer arrow.UnregisterExtensionType("uuid")
descr := &flight.FlightDescriptor{
Type: flight.FlightDescriptor_PATH,
Path: []string{s.path},
}
fmt.Println("Opening JSON file '", s.path, "'")
r, err := os.Open(s.path)
if err != nil {
return xerrors.Errorf("could not open JSON file: %q: %w", s.path, err)
}
rdr, err := arrjson.NewReader(r)
if err != nil {
return xerrors.Errorf("could not create JSON file reader from file: %q: %w", s.path, err)
}
dataSet := integrationDataSet{
chunks: make([]array.Record, 0),
schema: rdr.Schema(),
}
for {
rec, err := rdr.Read()
if err != nil {
if err == io.EOF {
break
}
return err
}
defer rec.Release()
dataSet.chunks = append(dataSet.chunks, rec)
}
stream, err := client.DoPut(ctx)
if err != nil {
return err
}
wr := flight.NewRecordWriter(stream, ipc.WithSchema(dataSet.schema))
wr.SetFlightDescriptor(descr)
for i, rec := range dataSet.chunks {
metadata := []byte(strconv.Itoa(i))
if err := wr.WriteWithAppMetadata(rec, metadata); err != nil {
return err
}
pr, err := stream.Recv()
if err != nil {
return err
}
acked := pr.GetAppMetadata()
switch {
case len(acked) == 0:
return xerrors.Errorf("expected metadata value: %s, but got nothing.", string(metadata))
case !bytes.Equal(metadata, acked):
return xerrors.Errorf("expected metadata value: %s, but got: %s", string(metadata), string(acked))
}
}
wr.Close()
if err := stream.CloseSend(); err != nil {
return err
}
info, err := client.GetFlightInfo(ctx, descr)
if err != nil {
return err
}
if len(info.Endpoint) == 0 {
fmt.Fprintln(os.Stderr, "no endpoints returned from flight server.")
return xerrors.Errorf("no endpoints returned from flight server")
}
for _, ep := range info.Endpoint {
if len(ep.Location) == 0 {
return xerrors.Errorf("no locations returned from flight server")
}
for _, loc := range ep.Location {
consumeFlightLocation(ctx, loc, ep.Ticket, dataSet.chunks, opts...)
}
}
return nil
}
func (s *defaultIntegrationTester) MakeServer(port int) flight.Server {
s.uploadedChunks = make(map[string]integrationDataSet)
srv := flight.NewServerWithMiddleware(nil, nil)
srv.RegisterFlightService(&flight.FlightServiceService{
GetFlightInfo: s.GetFlightInfo,
DoGet: s.DoGet,
DoPut: s.DoPut,
})
s.port = initServer(port, srv)
return srv
}
func (s *defaultIntegrationTester) GetFlightInfo(ctx context.Context, in *flight.FlightDescriptor) (*flight.FlightInfo, error) {
if in.Type == flight.FlightDescriptor_PATH {
if len(in.Path) == 0 {
return nil, status.Error(codes.InvalidArgument, "invalid path")
}
data, ok := s.uploadedChunks[in.Path[0]]
if !ok {
return nil, status.Errorf(codes.NotFound, "could not find flight: %s", in.Path[0])
}
flightData := &flight.FlightInfo{
Schema: flight.SerializeSchema(data.schema, memory.DefaultAllocator),
FlightDescriptor: in,
Endpoint: []*flight.FlightEndpoint{{
Ticket: &flight.Ticket{Ticket: []byte(in.Path[0])},
Location: []*flight.Location{{Uri: fmt.Sprintf("grpc+tcp://127.0.0.1:%d", s.port)}},
}},
TotalRecords: 0,
TotalBytes: -1,
}
for _, r := range data.chunks {
flightData.TotalRecords += r.NumRows()
}
return flightData, nil
}
return nil, status.Error(codes.Unimplemented, in.Type.String())
}
func (s *defaultIntegrationTester) DoGet(tkt *flight.Ticket, stream flight.FlightService_DoGetServer) error {
data, ok := s.uploadedChunks[string(tkt.Ticket)]
if !ok {
return status.Errorf(codes.NotFound, "could not find flight: %s", string(tkt.Ticket))
}
wr := flight.NewRecordWriter(stream, ipc.WithSchema(data.schema))
defer wr.Close()
for i, rec := range data.chunks {
wr.WriteWithAppMetadata(rec, []byte(strconv.Itoa(i)))
}
return nil
}
func (s *defaultIntegrationTester) DoPut(stream flight.FlightService_DoPutServer) error {
rdr, err := flight.NewRecordReader(stream)
if err != nil {
return status.Error(codes.Internal, err.Error())
}
var (
key string
dataset integrationDataSet
)
// creating the reader should have gotten the first message which would
// have the schema, which should have a populated flight descriptor
desc := rdr.LatestFlightDescriptor()
if desc.Type != flight.FlightDescriptor_PATH || len(desc.Path) < 1 {
return status.Error(codes.InvalidArgument, "must specify a path")
}
key = desc.Path[0]
dataset.schema = rdr.Schema()
dataset.chunks = make([]array.Record, 0)
for rdr.Next() {
rec := rdr.Record()
rec.Retain()
dataset.chunks = append(dataset.chunks, rec)
if len(rdr.LatestAppMetadata()) > 0 {
stream.Send(&flight.PutResult{AppMetadata: rdr.LatestAppMetadata()})
}
}
s.uploadedChunks[key] = dataset
return nil
}
func CheckActionResults(ctx context.Context, client flight.Client, action *flight.Action, results []string) error {
stream, err := client.DoAction(ctx, action)
if err != nil {
return err
}
defer stream.CloseSend()
for _, expected := range results {
res, err := stream.Recv()
if err != nil {
return err
}
actual := string(res.Body)
if expected != actual {
return xerrors.Errorf("got wrong result: expected: %s, got: %s", expected, actual)
}
}
res, err := stream.Recv()
if res != nil || err != io.EOF {
return xerrors.New("action result stream had too many entries")
}
return nil
}
const (
authUsername = "arrow"
authPassword = "flight"
)
type authBasicValidator struct {
auth flight.BasicAuth
}
func (a *authBasicValidator) Authenticate(conn flight.AuthConn) error {
token, err := conn.Read()
if err != nil {
return err
}
var incoming flight.BasicAuth
if err = proto.Unmarshal(token, &incoming); err != nil {
return err
}
if incoming.Username != a.auth.Username || incoming.Password != a.auth.Password {
return status.Error(codes.Unauthenticated, "invalid token")
}
return conn.Send([]byte(a.auth.Username))
}
func (a *authBasicValidator) IsValid(token string) (interface{}, error) {
if token != a.auth.Username {
return nil, status.Error(codes.Unauthenticated, "invalid token")
}
return token, nil
}
type clientAuthBasic struct {
auth *flight.BasicAuth
token string
}
func (c *clientAuthBasic) Authenticate(_ context.Context, conn flight.AuthConn) error {
if c.auth != nil {
data, err := proto.Marshal(c.auth)
if err != nil {
return err
}
if err = conn.Send(data); err != nil {
return err
}
token, err := conn.Read()
c.token = string(token)
if err != io.EOF {
return err
}
}
return nil
}
func (c *clientAuthBasic) GetToken(context.Context) (string, error) {
return c.token, nil
}
type authBasicProtoTester struct{}
func (s *authBasicProtoTester) RunClient(addr string, opts ...grpc.DialOption) error {
auth := &clientAuthBasic{}
client, err := flight.NewClientWithMiddleware(addr, auth, nil, opts...)
if err != nil {
return err
}
ctx := context.Background()
stream, err := client.DoAction(ctx, &flight.Action{})
if err != nil {
return err
}
// should fail unauthenticated
_, err = stream.Recv()
st, ok := status.FromError(err)
if !ok {
return err
}
if st.Code() != codes.Unauthenticated {
return xerrors.Errorf("expected Unauthenticated, got %s", st.Code())
}
auth.auth = &flight.BasicAuth{Username: authUsername, Password: authPassword}
if err := client.Authenticate(ctx); err != nil {
return err
}
return CheckActionResults(ctx, client, &flight.Action{}, []string{authUsername})
}
func (s *authBasicProtoTester) MakeServer(port int) flight.Server {
srv := flight.NewServerWithMiddleware(&authBasicValidator{
auth: flight.BasicAuth{Username: authUsername, Password: authPassword}}, nil)
srv.RegisterFlightService(&flight.FlightServiceService{
DoAction: s.DoAction,
})
initServer(port, srv)
return srv
}
func (authBasicProtoTester) DoAction(_ *flight.Action, stream flight.FlightService_DoActionServer) error {
auth := flight.AuthFromContext(stream.Context())
stream.Send(&flight.Result{Body: []byte(auth.(string))})
return nil
}
type middlewareScenarioTester struct{}
func (m *middlewareScenarioTester) RunClient(addr string, opts ...grpc.DialOption) error {
tm := &testClientMiddleware{}
client, err := flight.NewClientWithMiddleware(addr, nil, []flight.ClientMiddleware{
flight.CreateClientMiddleware(tm)}, opts...)
if err != nil {
return err
}
ctx := context.Background()
// this call is expected to fail
_, err = client.GetFlightInfo(ctx, &flight.FlightDescriptor{Type: flight.FlightDescriptor_CMD})
if err == nil {
return xerrors.New("expected call to fail")
}
if tm.received != "expected value" {
return xerrors.Errorf("expected to receive header 'x-middleware: expected value', but instead got %s", tm.received)
}
fmt.Fprintln(os.Stderr, "Headers received successfully on failing call.")
tm.received = ""
_, err = client.GetFlightInfo(ctx, &flight.FlightDescriptor{Type: flight.FlightDescriptor_CMD, Cmd: []byte("success")})
if err != nil {
return err
}
if tm.received != "expected value" {
return xerrors.Errorf("expected to receive header 'x-middleware: expected value', but instead got %s", tm.received)
}
fmt.Fprintln(os.Stderr, "Headers received successfully on passing call.")
return nil
}
func (m *middlewareScenarioTester) MakeServer(port int) flight.Server {
srv := flight.NewServerWithMiddleware(nil, []flight.ServerMiddleware{
flight.CreateServerMiddleware(testServerMiddleware{})})
srv.RegisterFlightService(&flight.FlightServiceService{
GetFlightInfo: m.GetFlightInfo,
})
initServer(port, srv)
return srv
}
func (m *middlewareScenarioTester) GetFlightInfo(ctx context.Context, desc *flight.FlightDescriptor) (*flight.FlightInfo, error) {
if desc.Type != flight.FlightDescriptor_CMD || string(desc.Cmd) != "success" {
return nil, status.Error(codes.Unknown, "unknown")
}
return &flight.FlightInfo{
Schema: flight.SerializeSchema(arrow.NewSchema([]arrow.Field{}, nil), memory.DefaultAllocator),
FlightDescriptor: desc,
Endpoint: []*flight.FlightEndpoint{{
Ticket: &flight.Ticket{Ticket: []byte("foo")},
Location: []*flight.Location{{Uri: "grpc+tcp://localhost:10010"}},
}},
TotalRecords: -1,
TotalBytes: -1,
}, nil
}