odps/restclient/rest_client.go (254 lines of code) (raw):

// Licensed to the Apache Software Foundation (ASF) under one // or more contributor license agreements. See the NOTICE file // distributed with this work for additional information // regarding copyright ownership. The ASF licenses this file // to you under the Apache License, Version 2.0 (the // "License"); you may not use this file except in compliance // with the License. You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. package restclient import ( "bytes" "encoding/xml" "fmt" "io" "log" "net" "net/http" "net/url" "os" "strings" "time" "github.com/pkg/errors" "github.com/aliyun/aliyun-odps-go-sdk/odps/account" "github.com/aliyun/aliyun-odps-go-sdk/odps/common" ) // Todo 请求方法需要重构,加入header参数 const ( DefaultHttpTimeout = 30 DefaultTcpConnectionTimeout = 30 ) type RestClient struct { account.Account // It is the total time from the tcp connection to http response. // the default value is 0, represents no timeout HttpTimeout time.Duration TcpConnectionTimeout time.Duration DnsCacheExpireTime time.Duration DisableCompression bool _client *http.Client defaultProject string currentSchema string endpoint string userAgent string } func NewOdpsRestClient(a account.Account, endpoint string) RestClient { client := RestClient{ Account: a, endpoint: endpoint, HttpTimeout: DefaultHttpTimeout * time.Second, TcpConnectionTimeout: DefaultTcpConnectionTimeout * time.Second, DnsCacheExpireTime: time.Duration(DefaultDNSCacheExpireTime) * time.Second, DisableCompression: true, } return client } func LoadEndpointFromEnv() string { endpoint, _ := os.LookupEnv("odps_endpoint") return endpoint } func (client *RestClient) SetDefaultProject(projectName string) { client.defaultProject = projectName } func (client *RestClient) SetCurrentSchema(schemaName string) { client.currentSchema = schemaName } func (client *RestClient) SetUserAgent(userAgent string) { client.userAgent = userAgent } func (client *RestClient) UserAgent() string { if client.userAgent != "" { return client.userAgent } return common.UserAgentValue } func (client *RestClient) Endpoint() string { return client.endpoint } func (client *RestClient) client() *http.Client { if client._client != nil { return client._client } resolver := NewResolver(int64(client.DnsCacheExpireTime) / int64(time.Second)) dialer := Dialer{ Resolver: resolver, Dialer: net.Dialer{ Timeout: client.TcpConnectionTimeout, KeepAlive: 30 * time.Second, }, } transport := http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: dialer.DialContext, ForceAttemptHTTP2: false, DisableCompression: client.DisableCompression, } client._client = &http.Client{ Transport: &transport, Timeout: client.HttpTimeout, } return client._client } func (client *RestClient) NewRequest(method, resource string, body io.Reader) (*http.Request, error) { urlStr := fmt.Sprintf( "%s/%s", strings.TrimRight(client.Endpoint(), "/"), strings.TrimLeft(resource, "/")) req, err := http.NewRequest(method, urlStr, body) return req, errors.WithStack(err) } func (client *RestClient) NewRequestWithUrlQuery(method, resource string, body io.Reader, queryArgs url.Values) (*http.Request, error) { req, err := client.NewRequest(method, resource, body) if err != nil { return nil, errors.WithStack(err) } if queryArgs != nil { req.URL.RawQuery = queryArgs.Encode() } return req, nil } func (client *RestClient) NewRequestWithParamsAndHeaders(method, resource string, body io.Reader, params url.Values, headers map[string]string) (*http.Request, error) { req, err := client.NewRequestWithUrlQuery(method, resource, body, params) if err != nil { return nil, err } if headers != nil { for name, value := range headers { req.Header.Set(name, value) } } return req, nil } func (client *RestClient) Do(req *http.Request) (*http.Response, error) { req.Header.Set(common.HttpHeaderUserAgent, common.UserAgentValue) req.Header.Set(common.HttpHeaderXOdpsUserAgent, client.UserAgent()) gmtTime := time.Now().In(common.GMT).Format(time.RFC1123) req.Header.Set(common.HttpHeaderDate, gmtTime) query := req.URL.Query() _, ok := query["current_project"] // in go1.17, 下面的语句应该这样写:if !query.Has("curr_project") && client.defaultProject != "" { // 但values.Has方法是在go1.17才引入的,为了兼容go1.15,所以不用Has方法 if !ok && client.defaultProject != "" { query.Set("curr_project", client.defaultProject) } req.URL.RawQuery = query.Encode() err := client.SignRequest(req, client.endpoint) if err != nil { return nil, err } res, err := client.client().Do(req) return res, errors.WithStack(err) } func (client *RestClient) DoWithParseFunc(req *http.Request, parseFunc func(res *http.Response) error) error { return client.DoWithParseRes(req, func(res *http.Response) error { if res.StatusCode < 200 || res.StatusCode >= 300 { return errors.WithStack(NewHttpNotOk(res)) } if parseFunc == nil { return nil } return errors.WithStack(parseFunc(res)) }) } func (client *RestClient) DoWithParseRes(req *http.Request, parseFunc func(res *http.Response) error) error { res, err := client.Do(req) if err != nil { return errors.WithStack(err) } defer func(Body io.ReadCloser) { err := Body.Close() if err != nil { log.Fatalf("close http error, url=%s", req.URL.String()) } }(res.Body) if parseFunc == nil { return nil } return errors.WithStack(parseFunc(res)) } func (client *RestClient) DoWithModel(req *http.Request, model interface{}) error { parseFunc := func(res *http.Response) error { decoder := xml.NewDecoder(res.Body) return errors.WithStack(decoder.Decode(model)) } return errors.WithStack(client.DoWithParseFunc(req, parseFunc)) } func (client *RestClient) GetWithModel(resource string, queryArgs url.Values, headers map[string]string, model interface{}) error { req, err := client.NewRequestWithParamsAndHeaders(common.HttpMethod.GetMethod, resource, nil, queryArgs, headers) if err != nil { return errors.WithStack(err) } return errors.WithStack(client.DoWithModel(req, model)) } func (client *RestClient) GetWithParseFunc(resource string, queryArgs url.Values, headers map[string]string, parseFunc func(res *http.Response) error) error { req, err := client.NewRequestWithParamsAndHeaders(common.HttpMethod.GetMethod, resource, nil, queryArgs, headers) if err != nil { return errors.WithStack(err) } return errors.WithStack(client.DoWithParseFunc(req, parseFunc)) } func (client *RestClient) PutWithParseFunc(resource string, queryArgs url.Values, body io.Reader, parseFunc func(res *http.Response) error) error { req, err := client.NewRequestWithUrlQuery(common.HttpMethod.PutMethod, resource, body, queryArgs) if err != nil { return errors.WithStack(err) } return errors.WithStack(client.DoWithParseFunc(req, parseFunc)) } func (client *RestClient) DoXmlWithParseFunc( method string, resource string, queryArgs url.Values, headers map[string]string, bodyModel interface{}, parseFunc func(res *http.Response) error, ) error { bodyXml, err := xml.Marshal(bodyModel) if err != nil { return errors.WithStack(err) } req, err := client.NewRequestWithUrlQuery(method, resource, bytes.NewReader(bodyXml), queryArgs) req.Header.Set(common.HttpHeaderContentType, common.XMLContentType) for name, value := range headers { req.Header.Set(name, value) } if err != nil { return errors.WithStack(err) } return errors.WithStack(client.DoWithParseFunc(req, parseFunc)) } func (client *RestClient) DoXmlWithParseRes( method string, resource string, queryArgs url.Values, headers map[string]string, bodyModel interface{}, parseFunc func(res *http.Response) error, ) error { bodyXml, err := xml.Marshal(bodyModel) if err != nil { return errors.WithStack(err) } req, err := client.NewRequestWithUrlQuery(method, resource, bytes.NewReader(bodyXml), queryArgs) req.Header.Set(common.HttpHeaderContentType, common.XMLContentType) for name, value := range headers { req.Header.Set(name, value) } if err != nil { return errors.WithStack(err) } return errors.WithStack(client.DoWithParseRes(req, parseFunc)) } func (client *RestClient) DoXmlWithModel( method string, resource string, queryArgs url.Values, bodyModel interface{}, resModel interface{}, ) error { parseFunc := func(res *http.Response) error { decoder := xml.NewDecoder(res.Body) if resModel == nil { return nil } return errors.WithStack(decoder.Decode(resModel)) } err := client.DoXmlWithParseFunc(method, resource, queryArgs, nil, bodyModel, parseFunc) return errors.WithStack(err) }