client/channel/channel.go (96 lines of code) (raw):

//
// Licensed to the Apache Software Foundation (ASF) under one or more
// contributor license agreements. See the NOTICE file distributed with
// this work for additional information regarding copyright ownership.
// The ASF licenses this file to You under the Apache License, Version 2.0
// (the "License"); you may not use this file except in compliance with
// the License. You may obtain a copy of the License at
//
//    http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Package channel parses Spark Connect ("sc://") connection strings and
// builds the gRPC client connection described by them.
package channel

import (
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"net"
	"net/url"
	"strconv"
	"strings"

	"golang.org/x/oauth2"
	"google.golang.org/grpc"
	"google.golang.org/grpc/credentials"
	"google.golang.org/grpc/credentials/oauth"
)

// reservedParams lists the connection-string parameters that are reserved by
// the client and must not be injected as variables (headers).
//
// NOTE(review): this slice is not referenced in this file. NewBuilder handles
// "token" and "user_id" explicitly, but "use_ssl" is currently stored in
// Headers like any other parameter. Confirm whether another file in the
// package consults this list.
var reservedParams = []string{"user_id", "token", "use_ssl"}

// ChannelBuilder holds the parameters parsed from a Spark Connect connection
// string, according to the specification documented here:
//
//	https://github.com/apache/spark/blob/master/connector/connect/docs/client-connection-string.md
//
// Use NewBuilder to populate it from a connection string and Build to open the
// gRPC connection.
type ChannelBuilder struct {
	// Host is the server host name or IP address, without the port.
	Host string
	// Port is the server port; NewBuilder defaults it to 15002 when the
	// connection string does not specify one.
	Port int
	// Token is the value of the "token" parameter. When non-empty, Build
	// enables TLS and sends it as an OAuth bearer credential.
	Token string
	// User is the value of the "user_id" parameter.
	User string
	// Headers holds the remaining key/value parameters of the connection
	// string, keyed by parameter name.
	Headers map[string]string
}

// Build finalizes the creation of the grpc.ClientConn by creating a gRPC
// channel with the options extracted from the connection string.
//
// Without a Token the channel is dialed without transport security. With a
// Token, Build loads the system certificate pool, enables TLS, and attaches the
// token as an OAuth bearer credential to every RPC.
//
// NOTE(review): Build does not apply Headers or User to the connection; confirm
// that callers attach them (e.g. as request metadata) elsewhere.
func (cb *ChannelBuilder) Build() (*grpc.ClientConn, error) { var opts []grpc.DialOption opts = append(opts, grpc.WithAuthority(cb.Host)) if cb.Token == "" { opts = append(opts, grpc.WithInsecure()) } else { // Note: On the Windows platform, use of x509.SystemCertPool() requires // go version 1.18 or higher. systemRoots, err := x509.SystemCertPool() if err != nil { return nil, err } cred := credentials.NewTLS(&tls.Config{ RootCAs: systemRoots, }) opts = append(opts, grpc.WithTransportCredentials(cred)) t := oauth2.Token{ AccessToken: cb.Token, TokenType: "bearer", } opts = append(opts, grpc.WithPerRPCCredentials(oauth.NewOauthAccess(&t))) } remote := fmt.Sprintf("%v:%v", cb.Host, cb.Port) conn, err := grpc.Dial(remote, opts...) if err != nil { return nil, fmt.Errorf("failed to connect to remote %s: %w", remote, err) } return conn, nil } // Creates a new instance of the ChannelBuilder. This constructor effectively // parses the connection string and extracts the relevant parameters directly. func NewBuilder(connection string) (*ChannelBuilder, error) { u, err := url.Parse(connection) if err != nil { return nil, err } if u.Scheme != "sc" { return nil, errors.New("URL schema must be set to `sc`.") } var port = 15002 var host = u.Host // Check if the host part of the URL contains a port and extract. if strings.Contains(u.Host, ":") { hostStr, portStr, err := net.SplitHostPort(u.Host) if err != nil { return nil, err } host = hostStr if len(portStr) != 0 { port, err = strconv.Atoi(portStr) if err != nil { return nil, err } } } // Validate that the URL path is empty or follows the right format. 
if u.Path != "" && !strings.HasPrefix(u.Path, "/;") { return nil, fmt.Errorf("The URL path (%v) must be empty or have a proper parameter syntax.", u.Path) } cb := &ChannelBuilder{ Host: host, Port: port, Headers: map[string]string{}, } elements := strings.Split(u.Path, ";") for _, e := range elements { props := strings.Split(e, "=") if len(props) == 2 { if props[0] == "token" { cb.Token = props[1] } else if props[0] == "user_id" { cb.User = props[1] } else { cb.Headers[props[0]] = props[1] } } } return cb, nil }