pkg/middleware/arm_error_collector.go (185 lines of code) (raw):

package middleware

import (
	"context"
	"crypto/tls"
	"errors"
	"fmt"
	"net/http"
	"net/http/httptrace"
	"time"

	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/arm"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
	"github.com/Azure/azure-sdk-for-go/sdk/azcore/to"
)

const (
	// Request headers from which the request id and correlation id are read
	// when a request completes.
	headerKeyRequestID     = "X-Ms-Client-Request-Id"
	headerKeyCorrelationID = "X-Ms-Correlation-Request-id"

	// ArmErrorCodeCastToArmResponseErrorFailed is reported when a 4xx/5xx
	// response could not be converted into an *azcore.ResponseError.
	ArmErrorCodeCastToArmResponseErrorFailed ArmErrorCode = "CastToArmResponseErrorFailed"
	// ArmErrorCodeTransportError is reported for request errors that are
	// neither context cancellation nor context deadline expiry.
	ArmErrorCodeTransportError ArmErrorCode = "TransportError"
	// ArmErrorCodeUnexpectedTransportError is reported when there is neither
	// an error nor a response to inspect.
	ArmErrorCodeUnexpectedTransportError ArmErrorCode = "UnexpectedTransportError"
	// ArmErrorCodeContextCanceled is reported when the request error matches
	// context.Canceled.
	ArmErrorCodeContextCanceled ArmErrorCode = "ContextCanceled"
	// ArmErrorCodeContextDeadlineExceeded is reported when the request error
	// matches context.DeadlineExceeded.
	ArmErrorCodeContextDeadlineExceeded ArmErrorCode = "ContextDeadlineExceeded"
)

// ArmError is unified Error Experience across AzureResourceManager, it contains Code Message.
type ArmError struct {
	Code    ArmErrorCode `json:"code"`
	Message string       `json:"message"`
}

// ArmErrorCode is the string code carried by an ArmError.
type ArmErrorCode string

// RequestInfo describes an outgoing ARM request handed to a collector.
type RequestInfo struct {
	// Request is the raw HTTP request as seen by the policy.
	Request *http.Request
	// ArmResId is the resource ID parsed from the request URL path; it is nil
	// when the path could not be parsed as an ARM resource ID.
	ArmResId *arm.ResourceID
}

// newRequestInfo bundles the raw request and its parsed resource ID.
func newRequestInfo(req *http.Request, resID *arm.ResourceID) *RequestInfo {
	return &RequestInfo{Request: req, ArmResId: resID}
}

// ResponseInfo describes the outcome of an ARM request handed to a collector.
type ResponseInfo struct {
	// Response is the HTTP response, or nil when the request failed.
	Response *http.Response
	// Error is populated when the request failed or returned a 4xx/5xx status.
	Error *ArmError
	// Latency is the time elapsed between RequestStarted and completion.
	Latency time.Duration
	// RequestId and CorrelationId are read from the request headers.
	RequestId     string
	CorrelationId string
	// ConnTracking holds the connection-level trace data for the request.
	ConnTracking *HttpConnTracking
}

// HttpConnTracking carries connection-level timings gathered through
// net/http/httptrace hooks. Each latency field holds either a "<n>ms" string
// or, when the corresponding phase failed, the error text.
type HttpConnTracking struct {
	TotalLatency string
	DnsLatency   string
	ConnLatency  string
	TlsLatency   string
	Protocol     string
	ReqConnInfo  *httptrace.GotConnInfo
}

// ArmRequestMetricCollector is an interface that collectors need to implement.
// TODO: use *policy.Request or *http.Request?
type ArmRequestMetricCollector interface {
	// RequestStarted is called when a request is about to be sent.
	// context is not provided, get it from RequestInfo.Request.Context()
	RequestStarted(*RequestInfo)

	// RequestCompleted is called when a request is finished
	// context is not provided, get it from RequestInfo.Request.Context()
	// if an error occurred, ResponseInfo.Error will be populated
	RequestCompleted(*RequestInfo, *ResponseInfo)
}

// ArmRequestMetricPolicy is a policy that collects metrics/telemetry for ARM requests.
type ArmRequestMetricPolicy struct {
	Collector ArmRequestMetricCollector
}

// Do implements the azcore/policy.Policy interface.
//
// It notifies the collector before the request is forwarded and again, from a
// deferred function, once the rest of the pipeline has returned. The response
// and error produced by the pipeline are returned to the caller unchanged.
func (p *ArmRequestMetricPolicy) Do(req *policy.Request) (*http.Response, error) {
	httpReq := req.Raw()
	if httpReq == nil || httpReq.URL == nil {
		// not able to collect telemetry, just pass through
		return req.Next()
	}

	armResID, err := arm.ParseResourceID(httpReq.URL.Path)
	if err != nil {
		// TODO: error handling without break the request.
		// Telemetry is best effort: report the request without a resource ID.
		armResID = nil
	}

	connTracking := &HttpConnTracking{}
	// have to add to the context at first - then clone the policy.Request struct
	// this allows the connection tracing to happen
	// otherwise we can't change the underlying http request of req, we have to use
	// newARMReq
	newCtx := addConnectionTracingToRequestContext(httpReq.Context(), connTracking)
	newARMReq := req.Clone(newCtx)

	requestInfo := newRequestInfo(httpReq, armResID)
	started := time.Now()
	p.requestStarted(requestInfo)

	var resp *http.Response
	var reqErr error

	// defer this function in case there's a panic somewhere down the pipeline.
	// It's the calling user's responsibility to handle panics, not this policy
	defer func() {
		respInfo := &ResponseInfo{
			Response:     resp,
			Latency:      time.Since(started),
			ConnTracking: connTracking,
		}

		if reqErr != nil {
			// either it's a transport error
			// or it is already handled by previous policy
			respInfo.Error = parseTransportError(reqErr)
		} else {
			respInfo.Error = parseArmErrorFromResponse(resp)
		}

		// need to get the request id and correlation id from the response.request header
		// because the headers were added by policy and might be called after this policy.
		// When no response (or no response.request) is available, e.g. on a transport
		// error, fall back to the request that was sent down the pipeline so the ids
		// are still reported on a best-effort basis.
		sentReq := newARMReq.Raw()
		if resp != nil && resp.Request != nil {
			sentReq = resp.Request
		}
		if sentReq != nil {
			respInfo.RequestId = sentReq.Header.Get(headerKeyRequestID)
			respInfo.CorrelationId = sentReq.Header.Get(headerKeyCorrelationID)
		}

		p.requestCompleted(requestInfo, respInfo)
	}()

	resp, reqErr = newARMReq.Next()
	return resp, reqErr
}

// shortcut function to handle nil collector
func (p *ArmRequestMetricPolicy) requestStarted(iReq *RequestInfo) {
	if p.Collector != nil {
		p.Collector.RequestStarted(iReq)
	}
}

// shortcut function to handle nil collector
func (p *ArmRequestMetricPolicy) requestCompleted(iReq *RequestInfo, iResp *ResponseInfo) {
	if p.Collector != nil {
		p.Collector.RequestCompleted(iReq, iResp)
	}
}

// parseArmErrorFromResponse derives an ArmError from an HTTP response.
//
// It returns nil for responses with a status code below 400. A nil response
// is reported as ArmErrorCodeUnexpectedTransportError. For 4xx/5xx responses
// the error code is taken from the *azcore.ResponseError built from the
// response; if that conversion fails, ArmErrorCodeCastToArmResponseErrorFailed
// is reported instead.
func parseArmErrorFromResponse(resp *http.Response) *ArmError {
	if resp == nil {
		return &ArmError{Code: ArmErrorCodeUnexpectedTransportError, Message: "nil response"}
	}
	if resp.StatusCode < http.StatusBadRequest {
		return nil
	}

	// for 4xx, 5xx response, ARM should include {error:{code, message}} in the body
	err := runtime.NewResponseError(resp)
	var respErr *azcore.ResponseError
	if errors.As(err, &respErr) {
		return &ArmError{Code: ArmErrorCode(respErr.ErrorCode), Message: respErr.Error()}
	}
	return &ArmError{
		Code:    ArmErrorCodeCastToArmResponseErrorFailed,
		Message: fmt.Sprintf("Response body is not in ARM error form: {error:{code, message}}: %s", err.Error()),
	}
}

// distinguish
// - Context Cancelled (request configured context to have timeout)
// - ClientTimeout
(context still valid, http client have timeout configured) // - Transport Error (DNS/Dial/TLS/ServerTimeout) func parseTransportError(err error) *ArmError { if err == nil { return nil } if errors.Is(err, context.Canceled) { return &ArmError{Code: ArmErrorCodeContextCanceled, Message: err.Error()} } if errors.Is(err, context.DeadlineExceeded) { return &ArmError{Code: ArmErrorCodeContextDeadlineExceeded, Message: err.Error()} } return &ArmError{Code: ArmErrorCodeTransportError, Message: err.Error()} } func addConnectionTracingToRequestContext(ctx context.Context, connTracking *HttpConnTracking) context.Context { var getConn, dnsStart, connStart, tlsStart *time.Time trace := &httptrace.ClientTrace{ GetConn: func(hostPort string) { getConn = to.Ptr(time.Now()) }, GotConn: func(connInfo httptrace.GotConnInfo) { if getConn != nil { connTracking.TotalLatency = fmt.Sprintf("%dms", time.Now().Sub(*getConn).Milliseconds()) } connTracking.ReqConnInfo = &connInfo }, DNSStart: func(_ httptrace.DNSStartInfo) { dnsStart = to.Ptr(time.Now()) }, DNSDone: func(dnsInfo httptrace.DNSDoneInfo) { if dnsInfo.Err == nil { if dnsStart != nil { connTracking.DnsLatency = fmt.Sprintf("%dms", time.Now().Sub(*dnsStart).Milliseconds()) } } else { connTracking.DnsLatency = dnsInfo.Err.Error() } }, ConnectStart: func(_, _ string) { connStart = to.Ptr(time.Now()) }, ConnectDone: func(_, _ string, err error) { if err == nil { if connStart != nil { connTracking.ConnLatency = fmt.Sprintf("%dms", time.Now().Sub(*connStart).Milliseconds()) } } else { connTracking.ConnLatency = err.Error() } }, TLSHandshakeStart: func() { tlsStart = to.Ptr(time.Now()) }, TLSHandshakeDone: func(t tls.ConnectionState, err error) { if err == nil { if tlsStart != nil { connTracking.TlsLatency = fmt.Sprintf("%dms", time.Now().Sub(*tlsStart).Milliseconds()) } connTracking.Protocol = t.NegotiatedProtocol } else { connTracking.TlsLatency = err.Error() } }, } ctx = httptrace.WithClientTrace(ctx, trace) return ctx }