common/apiserver/general.go (227 lines of code) (raw):

package apiserver import ( "crypto/tls" "crypto/x509" "errors" "sync" "github.com/aliyun/aliyun_assist_client/agent/flagging" "github.com/aliyun/aliyun_assist_client/thirdparty/sirupsen/logrus" "go.uber.org/atomic" "github.com/aliyun/aliyun_assist_client/common/httpbase" "github.com/aliyun/aliyun_assist_client/common/httputil" "github.com/aliyun/aliyun_assist_client/common/metaserver" "github.com/aliyun/aliyun_assist_client/common/networkcategory" "github.com/aliyun/aliyun_assist_client/common/requester" ) const ( IntranetDomain = ".axt.aliyun.com" ) var ( errUnknownDetectionResponse = errors.New("Unknown connection detection response") wellKnownRegionIds = []string{ "cn-hangzhou", "cn-qingdao", "cn-beijing", "cn-zhangjiakou", "cn-huhehaote", "cn-shanghai", "cn-shenzhen", "cn-hongkong", "eu-west-1", } ) type GeneralProvider struct { serverDomain atomic.String extraHTTPHeadersProvider atomic.Value } type GeneralHTTPHeadersProvider struct { instanceIdHeaders map[string]string initInstanceIdHeadersOnce sync.Once } var ( generalHTTPHeadersProvider = GeneralHTTPHeadersProvider{} ) func (*GeneralProvider) Name() string { return "GeneralProvider" } func (p *GeneralProvider) ServerDomain(logger logrus.FieldLogger) (string, error) { // 1. Read region id cached in file if exists regionId, _ := regionidFileProvider.RegionId(logger) if regionId != "" { domain := regionId + IntranetDomain if err := ConnectionDetect(logger, domain); err == nil { p.serverDomain.Store(domain) go p.cacheRegionId(logger, regionId) networkcategory.Set(networkcategory.NetworkVPC) return domain, nil } else { logger.WithFields(logrus.Fields{ "domain": domain, }).WithError(err).Error("Failed on detection of API server connection") } } // 2. Retrieve region id from meta server in VPC network regionId, _ = metaserverProvider.RegionId(logger) if regionId != "" { domain := regionId + IntranetDomain if err := ConnectionDetect(logger, domain); err == nil { p.serverDomain.Store(domain) go p.cacheRegionId(logger, regionId) networkcategory.Set(networkcategory.NetworkVPC) return domain, nil } else { logger.WithFields(logrus.Fields{ "domain": domain, }).WithError(err).Error("Failed on detection of API server connection") } } // 3. Poll well-known API servers for region id if flagging.GetApiserverTryPreset() { for _, regionId := range wellKnownRegionIds { regionId, err := HttpGetWithoutExtraHeader(logger, "https://"+regionId+IntranetDomain+"/luban/api/classic/region-id") if err != nil { continue } domain := regionId + IntranetDomain if err := ConnectionDetect(logger, domain); err == nil { p.serverDomain.Store(domain) go p.cacheRegionId(logger, regionId) networkcategory.Set(networkcategory.NetworkClassic) return domain, nil } else { logger.WithFields(logrus.Fields{ "domain": domain, }).WithError(err).Error("Failed on detection of API server connection") } } } return "", requester.ErrNotProvided } func (p *GeneralProvider) ExtraHTTPHeaders(logger logrus.FieldLogger) (map[string]string, error) { if p.serverDomain.Load() == "" { return nil, requester.ErrNotProvided } epp := p.extraHTTPHeadersProvider.Load() if epp == nil { return generalHTTPHeadersProvider.ExtraHTTPHeaders(logger) } ep, ok := epp.(requester.ExtraHTTPHeadersProvider) if !ok { return generalHTTPHeadersProvider.ExtraHTTPHeaders(logger) } return ep.ExtraHTTPHeaders(logger) } func (*GeneralProvider) cacheRegionId(logger logrus.FieldLogger, regionId string) { requester.SetRegionId(regionId) regionidFileProvider.SaveRegionId(logger, regionId) } func (gp *GeneralHTTPHeadersProvider) ExtraHTTPHeaders(logger logrus.FieldLogger) (map[string]string, error) { gp.initInstanceIdHeadersOnce.Do(func() { instanceId := func() string { instanceId, err := metaserver.GetInstanceId(logger) if err != nil { return "unknown" } return instanceId }() gp.instanceIdHeaders = map[string]string{ "X-Client-Instance-ID": instanceId, } }) return gp.instanceIdHeaders, nil } func ConnectionDetect(logger logrus.FieldLogger, domain string) error { url := "https://" + domain + "/luban/api/connection_detect" content, err := HttpGetWithoutExtraHeader(logger, url) if err != nil { return err } if content == "ok" { return nil } return errUnknownDetectionResponse } func HttpGetWithoutExtraHeader(logger logrus.FieldLogger, url string) (string, error) { return HttpGetWithSpecifiedHeader(logger, url, nil) } func HttpGetWithSpecifiedHeader(logger logrus.FieldLogger, url string, headers map[string]string) (string, error) { logger = logger.WithField("url", url) transport := requester.GetHTTPTransport(logger) request := httputil.NewGetReq(logger, transport, 5, headers) response, err := request.Get(url) if err != nil { var certificateErr *tls.CertificateVerificationError if !errors.As(err, &certificateErr) { logger.WithError(err).Error("Failed to send HTTP GET request") return "", err } // tls.CertificateVerificationError encountered. Gonna re-accumulate root CA // certificate pool and retry requesting. logger.Info("certificate error, reload certificates and retry") transport = transport.Clone() requester.AccumulateRootCAs(logger)(func(certPool *x509.CertPool) bool { request = httputil.NewGetReq(logger, transport, 5, headers) request.SetTLSClient(&tls.Config{ RootCAs: certPool, }) if response, err = request.Get(url); err == nil { logger.Info("certificate updated") requester.RefreshHTTPCas(logger, certPool) return false } return true }) if err != nil { logger.WithError(err).Error("Failed to send HTTP GET request") return "", err } } defer response.Close() content, _ := response.Content() if err == nil && response.StatusCode() > 400 { err = httpbase.NewStatusCodeError(response.StatusCode()) } logger.WithFields(logrus.Fields{ "url": url, "responseCode": response.StatusCode(), "responseContent": content, }).WithError(err).Infoln("HTTP GET Requested") return content, err } func HttpPostWithSpecifiedHeader(logger logrus.FieldLogger, url string, data string, contentType string, headers map[string]string) (string, error) { logger = logger.WithField("url", url) transport := requester.GetHTTPTransport(logger) request := httputil.NewPostReq(logger, transport, contentType, 5, headers) response, err := request.Post(url, data) if err != nil { var certificateErr *tls.CertificateVerificationError if !errors.As(err, &certificateErr) { logger.WithError(err).Error("Failed to send HTTP POST request") return "", err } // tls.CertificateVerificationError encountered. Gonna re-accumulate root CA // certificate pool and retry requesting. logger.Info("certificate error, reload certificates and retry") transport = transport.Clone() requester.AccumulateRootCAs(logger)(func(certPool *x509.CertPool) bool { request = httputil.NewPostReq(logger, transport, contentType, 5, headers) request.SetTLSClient(&tls.Config{ RootCAs: certPool, }) if response, err = request.Post(url, data); err == nil { logger.Info("certificate updated") requester.RefreshHTTPCas(logger, certPool) return false } return true }) if err != nil { logger.WithError(err).Error("Failed to send HTTP POST request") return "", err } } defer response.Close() content, _ := response.Content() if err == nil && response.StatusCode() > 400 { err = httpbase.NewStatusCodeError(response.StatusCode()) } logger.WithFields(logrus.Fields{ "url": url, "responseCode": response.StatusCode(), "responseContent": content, }).WithError(err).Infoln("HTTP POST Requested") return content, err }