spark/client/channel/channel.go (154 lines of code) (raw):
//
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package channel
import (
"context"
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net"
"net/url"
"os"
"runtime"
"strconv"
"strings"
"github.com/apache/spark-connect-go/v35/spark"
"github.com/google/uuid"
"google.golang.org/grpc/credentials/insecure"
"github.com/apache/spark-connect-go/v35/spark/sparkerrors"
"golang.org/x/oauth2"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/oauth"
)
// Builder abstracts how the GRPC connection to a Spark Connect server is created.
//
// Implementing this interface lets consumers plug in their own authentication and
// authorization handling without having to modify the Spark Connect client itself.
type Builder interface {
	// Build returns a *grpc.ClientConn configured according to this builder.
	// Implementations may accept additional parameters of their own; this method
	// is only the minimal contract they must satisfy.
	Build(ctx context.Context) (*grpc.ClientConn, error)

	// User returns the user name that is sent as part of every Spark Connect request.
	User() string

	// Headers returns the request metadata that is attached to every request the
	// client sends to the server.
	Headers() map[string]string

	// SessionId returns the client-side session identifier. The value must be a
	// UUID rendered as a string.
	SessionId() string

	// UserAgent returns the user agent string sent with each request. It carries
	// details such as the operating system and the Go version.
	UserAgent() string
}
// BaseBuilder is the default Builder implementation. It stores the parameters
// parsed from a Spark Connect connection string, following the specification at:
//
// https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md
type BaseBuilder struct {
	host      string            // hostname of the Spark Connect server
	port      int               // TCP port of the Spark Connect server
	token     string            // bearer token; when non-empty, Build uses TLS
	user      string            // user identifier sent with every request
	headers   map[string]string // additional request metadata from the connection string
	sessionId string            // client-side session identifier
	userAgent string            // user agent string sent with every request
}

// Host returns the configured server hostname.
func (cb *BaseBuilder) Host() string { return cb.host }

// Port returns the configured server port.
func (cb *BaseBuilder) Port() int { return cb.port }

// Token returns the configured bearer token, or "" when none was given.
func (cb *BaseBuilder) Token() string { return cb.token }

// User returns the user identifier sent with every request.
func (cb *BaseBuilder) User() string { return cb.user }

// Headers returns the additional request metadata. The builder's internal map is
// returned directly (not a copy), so modifications are visible to the builder.
func (cb *BaseBuilder) Headers() map[string]string { return cb.headers }

// SessionId returns the client-side session identifier.
func (cb *BaseBuilder) SessionId() string { return cb.sessionId }

// UserAgent returns the user agent string sent with every request.
func (cb *BaseBuilder) UserAgent() string { return cb.userAgent }
// Build finalizes the creation of the gprc.ClientConn by creating a GRPC channel
// with the necessary options extracted from the connection string. For
// TLS connections, this function will load the system certificates.
func (cb *BaseBuilder) Build(ctx context.Context) (*grpc.ClientConn, error) {
var opts []grpc.DialOption
opts = append(opts, grpc.WithAuthority(cb.host))
if cb.token == "" {
opts = append(opts, grpc.WithTransportCredentials(insecure.NewCredentials()))
} else {
// Note: On the Windows platform, use of x509.SystemCertPool() requires
// go version 1.18 or higher.
systemRoots, err := x509.SystemCertPool()
if err != nil {
return nil, err
}
cred := credentials.NewTLS(&tls.Config{
RootCAs: systemRoots,
})
opts = append(opts, grpc.WithTransportCredentials(cred))
ts := oauth2.StaticTokenSource(&oauth2.Token{
AccessToken: cb.token,
TokenType: "bearer",
})
opts = append(opts, grpc.WithPerRPCCredentials(oauth.TokenSource{TokenSource: ts}))
}
remote := fmt.Sprintf("%v:%v", cb.host, cb.port)
conn, err := grpc.NewClient(remote, opts...)
if err != nil {
return nil, sparkerrors.WithType(fmt.Errorf("failed to connect to remote %s: %w",
remote, err), sparkerrors.ConnectionError)
}
return conn, nil
}
// NewBuilder creates a new instance of the BaseBuilder. This constructor parses the
// connection string and extracts the relevant parameters directly.
//
// The connection string has the form
//
//	sc://host[:port][/;key=value[;key=value...]]
//
// The scheme must be "sc" and a hostname is required. When no port is given, 15002
// is used. IPv6 literals must be bracketed, e.g. sc://[::1]:15002.
//
// The following parameters to the connection string are reserved: user_id,
// session_id, user_agent, use_ssl, and token. These parameters are not allowed to be
// injected as headers. Every other key=value pair is passed to the server as a
// request header. Only the first '=' separates a key from its value, so values may
// themselves contain '=' (for example padded base64 tokens).
//
// If user_id is not given, the USER environment variable is used, falling back to
// "na". If user_agent is not given, SPARK_CONNECT_USER_AGENT is used, falling back
// to "_SPARK_CONNECT_GO". The Spark Connect version, operating system and Go version
// are always appended to the user agent.
func NewBuilder(connection string) (*BaseBuilder, error) {
	u, err := url.Parse(connection)
	if err != nil {
		return nil, err
	}
	if u.Hostname() == "" {
		return nil, sparkerrors.WithType(errors.New("URL must contain a hostname"), sparkerrors.InvalidInputError)
	}
	if u.Scheme != "sc" {
		return nil, sparkerrors.WithType(errors.New("URL schema must be set to `sc`"), sparkerrors.InvalidInputError)
	}

	// u.Hostname() strips the brackets of IPv6 literals and, unlike splitting
	// u.Host, still works when no port is present (e.g. "sc://[::1]").
	host := u.Hostname()
	port := 15002
	// SplitHostPort fails when the authority has no port at all ("host", "[::1]").
	// That is a valid connection string and simply selects the default port, so the
	// split error is deliberately not returned. url.Parse has already rejected
	// non-numeric ports.
	if _, portStr, splitErr := net.SplitHostPort(u.Host); splitErr == nil && portStr != "" {
		port, err = strconv.Atoi(portStr)
		if err != nil {
			return nil, err
		}
		if port < 1 || port > 65535 {
			return nil, sparkerrors.WithType(
				fmt.Errorf("the port (%d) must be between 1 and 65535", port),
				sparkerrors.InvalidInputError)
		}
	}

	// Validate that the URL path is empty or follows the right format.
	if u.Path != "" && !strings.HasPrefix(u.Path, "/;") {
		return nil, sparkerrors.WithType(
			fmt.Errorf("the URL path (%v) must be empty or have a proper parameter syntax", u.Path),
			sparkerrors.InvalidInputError)
	}

	cb := &BaseBuilder{
		host:      host,
		port:      port,
		headers:   map[string]string{},
		sessionId: uuid.NewString(),
	}

	for _, e := range strings.Split(u.Path, ";") {
		// Cut on the first '=' only. Values containing '=' (e.g. padded base64
		// tokens) are kept intact instead of being silently dropped. Elements
		// without '=' (such as the leading "/") are ignored.
		key, value, found := strings.Cut(e, "=")
		if !found {
			continue
		}
		switch key {
		case "token":
			cb.token = value
		case "user_id":
			cb.user = value
		case "session_id":
			cb.sessionId = value
		case "user_agent":
			cb.userAgent = value
		case "use_ssl":
			// Reserved parameter: it must not be forwarded as a header.
			// NOTE(review): Build currently enables TLS only when a token is set;
			// use_ssl itself is not honored yet.
		default:
			cb.headers[key] = value
		}
	}

	// Set default user ID if not set.
	if cb.user == "" {
		cb.user = os.Getenv("USER")
		if cb.user == "" {
			cb.user = "na"
		}
	}

	// Fall back to the environment, then to the built-in default, when no user
	// agent was given in the connection string.
	if cb.userAgent == "" {
		cb.userAgent = os.Getenv("SPARK_CONNECT_USER_AGENT")
		if cb.userAgent == "" {
			cb.userAgent = "_SPARK_CONNECT_GO"
		}
	}

	// In addition to the specified user agent, append information about the
	// host encoded as user agent components.
	cb.userAgent = fmt.Sprintf("%s spark/%s os/%s go/%s", cb.userAgent, spark.Version(), runtime.GOOS, runtime.Version())
	return cb, nil
}