azkustodata/conn.go (216 lines of code) (raw):
package azkustodata
// conn.go holds the connection to the Kusto server and provides methods to run queries
// and management commands and receive Kusto frames back.
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"strings"
"sync"
"sync/atomic"
"unicode"
"github.com/Azure/azure-kusto-go/azkustodata/errors"
"github.com/Azure/azure-kusto-go/azkustodata/internal/response"
truestedEndpoints "github.com/Azure/azure-kusto-go/azkustodata/trusted_endpoints"
"github.com/google/uuid"
)
// bufferPool recycles the *bytes.Buffer values used to JSON-encode request
// bodies, so each request does not have to allocate a fresh buffer.
var bufferPool = sync.Pool{
	New: func() any { return new(bytes.Buffer) },
}
// TODO - inspect this. Do we need this? can this be simplified?

// Conn provides connectivity to a Kusto instance.
//
// It keeps the endpoint it was created for, the Authorization used to obtain
// tokens, the REST URLs derived from that endpoint, and the *http.Client that
// sends every request. Whether the endpoint has passed trusted-endpoint
// validation is tracked in an atomic flag.
type Conn struct {
	endpoint string        // the endpoint string exactly as given to NewConn
	auth     Authorization // source of tokens for outgoing requests

	endMgmt         *url.URL // <endpoint>/v1/rest/mgmt
	endQuery        *url.URL // <endpoint>/v2/rest/query
	endStreamIngest *url.URL // <endpoint>/v1/rest/ingest

	client            *http.Client   // sends every request
	endpointValidated atomic.Bool    // set once endpoint validation has succeeded
	clientDetails     *ClientDetails // default application/user/version tracing values
}
// NewConn returns a new Conn object with an injected http.Client.
//
// endpoint must be non-empty and parseable by url.Parse. If auth's token provider
// requires authorization, the endpoint's scheme must be https so that tokens are
// never sent in clear text. A nil token provider is treated as requiring no
// authorization, which matches the check made before tokens are acquired for a
// request. The management, query and streaming-ingest URLs are derived from the
// endpoint by appending /v1/rest/mgmt, /v2/rest/query and /v1/rest/ingest to its path.
func NewConn(endpoint string, auth Authorization, client *http.Client, clientDetails *ClientDetails) (*Conn, error) {
	if endpoint == "" {
		return nil, errors.ES(errors.OpQuery, errors.KClientArgs, "endpoint cannot be empty")
	}

	u, err := url.Parse(endpoint)
	if err != nil {
		return nil, errors.ES(errors.OpServConn, errors.KClientArgs, "could not parse the endpoint(%s): %s", endpoint, err).SetNoRetry()
	}

	// Check for a nil TokenProvider before calling its methods; without the guard a
	// Conn built with no token provider could panic here.
	if u.Scheme != "https" && auth.TokenProvider != nil && auth.TokenProvider.AuthorizationRequired() {
		return nil, errors.ES(errors.OpServConn, errors.KClientArgs, "cannot use token provider with http endpoint, as it would send the token in clear text").SetNoRetry()
	}

	// Root the path before the REST paths are joined onto it.
	if !strings.HasPrefix(u.Path, "/") {
		u.Path = "/" + u.Path
	}

	return &Conn{
		auth:            auth,
		endMgmt:         u.JoinPath("/v1/rest/mgmt"),
		endQuery:        u.JoinPath("/v2/rest/query"),
		endStreamIngest: u.JoinPath("/v1/rest/ingest"),
		client:          client,
		clientDetails:   clientDetails,
		endpoint:        endpoint,
	}, nil
}
// queryMsg is the JSON document POSTed as the body of query and management
// requests. It carries the target database, the command text and the request
// properties, encoded under the keys "db", "csl" and "properties" in that order.
//
// NOTE(review): requestProperties appears to be a struct type. encoding/json
// never treats a struct value as empty, so "omitempty" on Properties likely has
// no effect. Confirm this before relying on it.
type queryMsg struct {
	// DB names the database the command runs against.
	DB string `json:"db"`
	// CSL is the command text, including any prepended parameter declarations.
	CSL string `json:"csl"`
	// Properties carries the per-request options and query parameters.
	Properties requestProperties `json:"properties,omitempty"`
}
// connOptions wraps a *queryOptions.
//
// NOTE(review): not referenced in this file. Confirm it is still used elsewhere
// in the package before changing or removing it.
type connOptions struct {
	queryOptions *queryOptions
}
// rawQuery sends query to database db and returns the unparsed response body.
//
// callType is forwarded to doRequest as its exec type. The request properties
// come from options.requestProperties. Both options and that field are
// dereferenced unconditionally, so they must be non-nil. The returned body is
// not closed here; the caller is responsible for closing it.
func (c *Conn) rawQuery(ctx context.Context, callType callType, db string, query Statement, options *queryOptions) (io.ReadCloser, error) {
	props := *options.requestProperties
	_, _, _, respBody, err := c.doRequest(ctx, int(callType), db, query, props)
	if err != nil {
		return nil, err
	}
	return respBody, nil
}
// Request kinds accepted by doRequest's execType parameter.
const (
	execQuery = iota + 1 // a query, sent to the v2 query endpoint
	execMgmt             // a management command, sent to the v1 mgmt endpoint
)
// pooledRequestBody is the request body doRequest hands to the HTTP client.
//
// It reads from a *bytes.Buffer taken from bufferPool and returns that buffer to
// the pool only when Close is called. Under the http.RoundTripper contract the
// transport may keep reading the body after Client.Do has returned, and it
// signals that it is finished by closing the body. Recycling the buffer any
// earlier (for example in a defer inside doRequest) would let a concurrent
// request overwrite bytes that are still being sent.
type pooledRequestBody struct {
	*bytes.Buffer
	once sync.Once
}

// Close returns the underlying buffer to bufferPool. Only the first call has an
// effect, so the same buffer is never put into the pool twice.
func (b *pooledRequestBody) Close() error {
	b.once.Do(func() { bufferPool.Put(b.Buffer) })
	return nil
}

// doRequest encodes a query or management command as JSON and POSTs it to the
// matching Kusto REST endpoint.
//
// execType selects both the target URL and the errors.Op attached to errors:
//   - execQuery sends to c.endQuery and uses errors.OpQuery.
//   - execMgmt sends to c.endMgmt and uses errors.OpMgmt.
//   - Any other value yields an internal error.
//
// The connection's endpoint is validated before anything is sent. When the
// statement cannot carry its parameters inline and properties declares query
// parameters, their declaration is prepended to the command text.
//
// It returns the Op, the request headers that were sent, the response headers
// and the response body, which the caller must close. If the request fails
// before it reaches doRequestImpl, only the error is set.
func (c *Conn) doRequest(ctx context.Context, execType int, db string, query Statement, properties requestProperties) (errors.Op, http.Header, http.Header, io.ReadCloser, error) {
	var (
		op       errors.Op
		endpoint *url.URL
	)
	switch execType {
	case execQuery:
		op, endpoint = errors.OpQuery, c.endQuery
	case execMgmt:
		op, endpoint = errors.OpMgmt, c.endMgmt
	default:
		return 0, nil, nil, nil, errors.ES(op, errors.KInternal, "internal error: did not understand the type of execType: %d", execType)
	}

	// op is known at this point, so a validation failure is reported with the Op
	// of the actual request kind (mgmt or query).
	if err := c.validateEndpoint(); err != nil {
		return 0, nil, nil, nil, errors.E(op, errors.KInternal, fmt.Errorf("could not validate endpoint: %w", err))
	}

	queryText := query.String()
	csl := queryText
	if !query.SupportsInlineParameters() && properties.QueryParameters.Count() != 0 {
		csl = fmt.Sprintf("%s\n%s", properties.QueryParameters.ToDeclarationString(), queryText)
	}

	buff := bufferPool.Get().(*bytes.Buffer)
	buff.Reset()
	if err := json.NewEncoder(buff).Encode(queryMsg{DB: db, CSL: csl, Properties: properties}); err != nil {
		// The buffer never left this function, so it can be recycled immediately.
		bufferPool.Put(buff)
		return 0, nil, nil, nil, errors.E(op, errors.KInternal, fmt.Errorf("could not JSON marshal the Query message: %w", err))
	}

	headers := c.getHeaders(properties)
	// From here on the request body owns buff, and buff returns to the pool when
	// the body is closed. If the body is never closed (doRequestImpl can fail before
	// calling the HTTP client), the buffer is simply garbage-collected.
	responseHeaders, closer, err := c.doRequestImpl(ctx, op, endpoint, &pooledRequestBody{Buffer: buff}, headers, fmt.Sprintf("With query: %s", queryText))
	return op, headers, responseHeaders, closer, err
}
// doRequestImpl POSTs buff to endpoint. It returns the response headers and the
// response body produced by response.TranslateBody.
//
// Before sending, every rune above unicode.MaxASCII in the header values is
// replaced with '?', and headers is modified in place. If the connection's token
// provider requires authorization, a token is acquired and sent in the
// Authorization header. errorContext is included in the messages of transport
// errors and non-200 status errors.
//
// buff is handed to c.client.Do, which closes it even on error. If token
// acquisition fails before that point, a non-nil buff is closed here instead,
// so the body is closed on every path.
func (c *Conn) doRequestImpl(
	ctx context.Context,
	op errors.Op,
	endpoint *url.URL,
	buff io.ReadCloser,
	headers http.Header,
	errorContext string) (http.Header, io.ReadCloser, error) {
	toASCII := func(r rune) rune {
		if r > unicode.MaxASCII {
			return '?'
		}
		return r
	}
	// Map each value independently, so the values of a multi-valued header never
	// get concatenated into one another.
	for _, values := range headers {
		for i, v := range values {
			values[i] = strings.Map(toASCII, v)
		}
	}

	if c.auth.TokenProvider != nil && c.auth.TokenProvider.AuthorizationRequired() {
		c.auth.TokenProvider.SetHttp(c.client)
		token, tokenType, tkerr := c.auth.TokenProvider.AcquireToken(ctx)
		if tkerr != nil {
			if buff != nil {
				_ = buff.Close()
			}
			// Wrap with %w so the provider's error stays in the chain. This assumes
			// errors.E's result unwraps to its inner error, as at the other call sites.
			return nil, nil, errors.E(op, errors.KInternal, fmt.Errorf("Error while getting token : %w", tkerr))
		}
		// Use Set rather than Add: a request must carry a single Authorization header.
		headers.Set("Authorization", fmt.Sprintf("%s %s", tokenType, token))
	}

	req := &http.Request{
		Method: http.MethodPost,
		URL:    endpoint,
		Header: headers,
		Body:   buff,
	}
	resp, err := c.client.Do(req.WithContext(ctx))
	if err != nil {
		// TODO(jdoak): We need a http error unwrap function that pulls out an *errors.Error.
		return nil, nil, errors.E(op, errors.KHTTPError, fmt.Errorf("%v, %w", errorContext, err))
	}

	body, err := response.TranslateBody(resp, op)
	if err != nil {
		// NOTE(review): it is not visible here whether TranslateBody closes resp.Body
		// when it fails. Closing it guarantees the connection is released. Bodies
		// returned by net/http's Transport are expected to tolerate a second Close.
		_ = resp.Body.Close()
		return nil, nil, err
	}
	if resp.StatusCode != http.StatusOK {
		return nil, nil, errors.HTTP(op, resp.Status, resp.StatusCode, body, fmt.Sprintf("error from Kusto endpoint, %v", errorContext))
	}
	return resp.Header, body, nil
}
// validateEndpoint checks that c.endpoint is trusted for the login endpoint
// reported by its cloud metadata.
//
// Any error from GetMetadata or from the trusted-endpoints check is returned.
// In that case the endpoint stays unvalidated, so the check runs again on the
// next request. After the first success the result is cached, and later calls
// return nil without doing any work. The flag is read and written atomically
// without a lock, so concurrent first callers may each perform the check.
func (c *Conn) validateEndpoint() error {
	if c.endpointValidated.Load() {
		return nil
	}

	// err is declared in this scope on purpose. Re-declaring it with := inside an
	// if initializer would shadow it and turn every failure into a nil return.
	cloud, err := GetMetadata(c.endpoint, c.client)
	if err != nil {
		return err
	}
	if err := truestedEndpoints.Instance.ValidateTrustedEndpoint(c.endpoint, cloud.LoginEndpoint); err != nil {
		return err
	}

	c.endpointValidated.Store(true)
	return nil
}
// Names of the tracing headers attached to every request.
const (
	// ClientRequestIdHeader carries the client request ID of a request.
	ClientRequestIdHeader = "x-ms-client-request-id"
	// ApplicationHeader carries the name of the calling application.
	ApplicationHeader = "x-ms-app"
	// UserHeader carries the name of the calling user.
	UserHeader = "x-ms-user"
	// ClientVersionHeader carries the client version string.
	ClientVersionHeader = "x-ms-client-version"
)
// getHeaders builds the HTTP headers sent with every query and management request.
//
// The fixed headers are Accept and Content-Type (JSON), Accept-Encoding
// (gzip, deflate), Connection (Keep-Alive) and x-ms-version. Each tracing header
// takes its value from properties when that value is non-empty. Otherwise:
//   - x-ms-client-request-id is "KGC.execute;" followed by a newly generated UUID.
//   - x-ms-app and x-ms-user come from c.clientDetails.
//
// x-ms-client-version always comes from c.clientDetails.
func (c *Conn) getHeaders(properties requestProperties) http.Header {
	requestID := properties.ClientRequestID
	if requestID == "" {
		requestID = "KGC.execute;" + uuid.New().String()
	}
	application := properties.Application
	if application == "" {
		application = c.clientDetails.ApplicationForTracing()
	}
	user := properties.User
	if user == "" {
		user = c.clientDetails.UserNameForTracing()
	}

	h := make(http.Header, 9)
	for _, kv := range [...][2]string{
		{"Accept", "application/json"},
		{"Accept-Encoding", "gzip, deflate"},
		{"Content-Type", "application/json; charset=utf-8"},
		{"Connection", "Keep-Alive"},
		{"x-ms-version", "2024-12-12"},
		{ClientRequestIdHeader, requestID},
		{ApplicationHeader, application},
		{UserHeader, user},
		{ClientVersionHeader, c.clientDetails.ClientVersionForTracing()},
	} {
		// Every key is distinct and the header starts out empty, so Set behaves
		// exactly like Add here.
		h.Set(kv[0], kv[1])
	}
	return h
}
// Close releases the idle keep-alive connections held by the connection's
// *http.Client. Connections currently in use are not interrupted. Close always
// returns nil.
func (c *Conn) Close() error {
	httpClient := c.client
	httpClient.CloseIdleConnections()
	return nil
}