http_client.go (116 lines of code) (raw):

/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to you under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package avatica import ( "bytes" "context" "database/sql/driver" "fmt" "io" "net" "net/http" "regexp" "runtime" "time" avaticaMessage "github.com/apache/calcite-avatica-go/v5/message" "google.golang.org/protobuf/proto" ) var ( badConnRe = regexp.MustCompile(`org\.apache\.calcite\.avatica\.NoSuchConnectionException`) ) // httpClient wraps the default http.Client to communicate with the Avatica server. type httpClient struct { host string httpClient *http.Client } type avaticaError struct { message *avaticaMessage.ErrorResponse } func (e avaticaError) Error() string { return fmt.Sprintf("avatica encountered an error: %s", e.message.GetErrorMessage()) } // NewHTTPClient creates a new httpClient from a host. func NewHTTPClient(host string, baseClient *http.Client, config *Config) (*httpClient, error) { if baseClient == nil { baseClient = &http.Client{ Transport: &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, }).DialContext, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, MaxIdleConnsPerHost: runtime.GOMAXPROCS(0) + 1, }, } switch config.authentication { case digest: baseClient = WithDigestAuth(baseClient, config.avaticaUser, config.avaticaPassword) case basic: baseClient = WithBasicAuth(baseClient, config.avaticaUser, config.avaticaPassword) case spnego: user := config.principal.username realm := config.principal.realm cli, err := WithKerberosAuth(baseClient, user, realm, config.keytab, config.krb5Conf, config.krb5CredentialCache) if err != nil { return nil, fmt.Errorf("can't add kerberos authentication to http client: %w", err) } baseClient = cli } } c := &httpClient{ host: host, httpClient: baseClient, } return c, nil } // post posts a protocol buffer message to the Avatica server. func (c *httpClient) post(ctx context.Context, message proto.Message) (proto.Message, error) { wrapped, err := proto.Marshal(message) if err != nil { return nil, fmt.Errorf("error marshaling request message to protobuf: %w", err) } wire := avaticaMessage.WireMessage_builder{ Name: classNameFromRequest(message), WrappedMessage: wrapped, }.Build() body, err := proto.Marshal(wire) if err != nil { return nil, fmt.Errorf("error marshaling wire message to protobuf: %w", err) } req, err := http.NewRequest("POST", c.host, bytes.NewReader(body)) if err != nil { return nil, fmt.Errorf("error creating http request: %w", err) } req.Header.Set("Content-Type", "application/x-google-protobuf") req = req.WithContext(ctx) res, err := c.httpClient.Do(req) if err != nil { return nil, fmt.Errorf("error executing http request: %w", err) } defer res.Body.Close() response, err := io.ReadAll(res.Body) if err != nil { return nil, fmt.Errorf("error reading response body: %w", err) } result := &avaticaMessage.WireMessage{} err = proto.Unmarshal(response, result) if err != nil { return nil, fmt.Errorf("error unmarshaling wire message: %w", err) } inner, err := responseFromClassName(result.GetName()) if err != nil { return nil, fmt.Errorf("error getting wrapped response from wire message: %w", err) } err = proto.Unmarshal(result.GetWrappedMessage(), inner) if err != nil { return nil, fmt.Errorf("error unmarshaling wrapped message: %w", err) } if v, ok := inner.(*avaticaMessage.ErrorResponse); ok { for _, exception := range v.GetExceptions() { if badConnRe.MatchString(exception) { return nil, driver.ErrBadConn } } return nil, avaticaError{v} } return inner, nil }