internal/clients/live_traffic_log_policy.go (102 lines of code) (raw):
package clients
import (
"encoding/json"
"io"
"log"
"net/http"
"strings"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
)
const redactedValue = "REDACTED"
type liveTrafficLogPolicy struct {
notAllowedHeaders map[string]bool
}
type traffic struct {
LiveRequest liveRequest `json:"request"`
LiveResponse liveResponse `json:"response"`
}
type liveRequest struct {
Headers map[string]string `json:"headers"`
Method string `json:"method"`
Url string `json:"url"`
Body string `json:"body"`
}
type liveResponse struct {
StatusCode int `json:"statusCode"`
Headers map[string]string `json:"headers"`
Body string `json:"body"`
}
func NewLiveTrafficLogPolicy() policy.Policy {
return &liveTrafficLogPolicy{
notAllowedHeaders: map[string]bool{
"authorization": true,
},
}
}
func (p *liveTrafficLogPolicy) Do(req *policy.Request) (*http.Response, error) {
rawRequest := req.Raw()
liveReq := liveRequest{
Headers: p.header(rawRequest.Header),
Method: rawRequest.Method,
Url: rawRequest.URL.String(),
Body: p.requestBodyString(req),
}
if err := req.RewindBody(); err != nil {
return nil, err
}
response, err := req.Next() // Make the request
liveResp := liveResponse{}
if err == nil {
liveResp.Headers = p.header(response.Header)
liveResp.StatusCode = response.StatusCode
liveResp.Body = p.responseBodyString(response)
} else {
liveResp.Body = err.Error()
}
liveTraffic := traffic{
LiveRequest: liveReq,
LiveResponse: liveResp,
}
data, marshalErr := json.Marshal(liveTraffic)
if marshalErr != nil {
log.Printf("[ERROR] Failed to marshal live traffic: %v", marshalErr)
return response, err
}
log.Printf("[DEBUG] Live traffic: %s", string(data))
return response, err
}
func (p *liveTrafficLogPolicy) requestBodyString(req *policy.Request) string {
if req.Raw().Body == nil {
return ""
}
body, err := io.ReadAll(req.Raw().Body)
if err != nil {
log.Printf("[ERROR] Failed to read request body: %v", err)
body = []byte(err.Error())
}
if err := req.RewindBody(); err != nil {
log.Printf("[ERROR] Failed to rewind request body: %v", err)
return ""
}
return string(body)
}
func (p *liveTrafficLogPolicy) responseBodyString(resp *http.Response) string {
body, err := runtime.Payload(resp)
if err != nil {
log.Printf("[ERROR] Failed to read response body: %v", err)
return ""
}
return string(body)
}
func (p *liveTrafficLogPolicy) header(input http.Header) map[string]string {
output := make(map[string]string)
for k, v := range input {
if p.notAllowedHeaders[strings.ToLower(k)] {
output[k] = redactedValue
} else {
output[k] = strings.Join(v, ",")
}
}
return output
}