sdk/internal/recording/recording.go (569 lines of code) (raw):

//go:build go1.18 // +build go1.18 // Copyright (c) Microsoft Corporation. All rights reserved. // Licensed under the MIT License. package recording import ( "bytes" "crypto/tls" "crypto/x509" "encoding/base64" "encoding/json" "errors" "fmt" "io" "io/fs" "log" "math/rand" "net/http" "os" "os/exec" "path/filepath" "runtime" "strconv" "strings" "sync" "testing" "time" "github.com/Azure/azure-sdk-for-go/sdk/internal/uuid" ) // Deprecated: the local recording API that uses this type is no longer supported. Call [Start] and [Stop] // to make recordings via the test proxy instead. type Recording struct { SessionName string RecordingFile string VariablesFile string Mode RecordMode Sanitizer *Sanitizer Matcher *RequestMatcher } const ( alphanumericBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890" alphanumericLowercaseBytes = "abcdefghijklmnopqrstuvwxyz1234567890" randomSeedVariableName = "randomSeed" nowVariableName = "now" ModeEnvironmentVariableName = "AZURE_TEST_MODE" recordingAssetConfigName = "assets.json" ) // Inspired by https://stackoverflow.com/questions/22892120/how-to-generate-a-random-string-of-a-fixed-length-in-go const ( letterIdxBits = 6 // 6 bits to represent a letter index letterIdxMask = 1<<letterIdxBits - 1 // All 1-bits, as many as letterIdxBits letterIdxMax = 63 / letterIdxBits // # of letter indices fitting in 63 bits ) type RecordMode string const ( Record RecordMode = "record" Playback RecordMode = "playback" Live RecordMode = "live" ) // Deprecated: only deprecated methods use this type. Call [Start] and [Stop] to make recordings. type VariableType string const ( // NoSanitization indicates that the recorded value should not be sanitized. NoSanitization VariableType = "default" // Secret_String indicates that the recorded value should be replaced with a sanitized value. Secret_String VariableType = "secret_string" // Secret_Base64String indicates that the recorded value should be replaced with a sanitized valid base-64 string value. Secret_Base64String VariableType = "secret_base64String" ) var errUnsupportedAPI = errors.New("the vcr based test recording API isn't supported. Use the test proxy instead") // NewRecording initializes a new Recording instance func NewRecording(c TestContext, mode RecordMode) (*Recording, error) { return nil, errUnsupportedAPI } // GetEnvVar returns a recorded environment variable. If the variable is not found we return an error. // variableType determines how the recorded variable will be saved. func (r *Recording) GetEnvVar(name string, variableType VariableType) (string, error) { return "", errUnsupportedAPI } // GetOptionalEnvVar returns a recorded environment variable with a fallback default value. // default Value configures the fallback value to be returned if the environment variable is not set. // variableType determines how the recorded variable will be saved. func (r *Recording) GetOptionalEnvVar(name string, defaultValue string, variableType VariableType) string { panic(errUnsupportedAPI) } // Do satisfies the azcore.Transport interface so that Recording can be used as the transport for recorded requests func (r *Recording) Do(req *http.Request) (*http.Response, error) { return nil, errUnsupportedAPI } // Stop stops the recording and saves them, including any captured variables, to disk func (r *Recording) Stop() error { return errUnsupportedAPI } func (r *Recording) Now() time.Time { panic(errUnsupportedAPI) } func (r *Recording) UUID() uuid.UUID { panic(errUnsupportedAPI) } // GenerateAlphaNumericID will generate a recorded random alpha numeric id // if the recording has a randomSeed already set, the value will be generated from that seed, else a new random seed will be used func (r *Recording) GenerateAlphaNumericID(prefix string, length int, lowercaseOnly bool) (string, error) { return "", errUnsupportedAPI } func init() { recordMode = os.Getenv("AZURE_RECORD_MODE") if recordMode == "" { log.Printf("AZURE_RECORD_MODE was not set, defaulting to playback") recordMode = PlaybackMode } if !(recordMode == RecordingMode || recordMode == PlaybackMode || recordMode == LiveMode) { log.Panicf("AZURE_RECORD_MODE was not understood, options are %s, %s, or %s Received: %v.\n", RecordingMode, PlaybackMode, LiveMode, recordMode) } localFile, err := findProxyCertLocation() if err != nil { log.Println("Could not find the PROXY_CERT environment variable and was unable to locate the path in eng/common") } var certPool *x509.CertPool if runtime.GOOS == "windows" { certPool = x509.NewCertPool() } else { certPool, err = x509.SystemCertPool() if err != nil { log.Println("could not create a system cert pool") log.Panic(err.Error()) } } cert, err := os.ReadFile(localFile) if err != nil { log.Printf("could not read file set in PROXY_CERT variable at %s.\n", localFile) } if ok := certPool.AppendCertsFromPEM(cert); !ok { log.Println("no certs appended, using system certs only") } // Set a Default matcher that ignores :path, :scheme, :authority, and :method headers err = SetDefaultMatcher( nil, &SetDefaultMatcherOptions{ExcludedHeaders: []string{ ":authority", ":method", ":path", ":scheme", }}, ) if err != nil { log.Println("could not set the default matcher") } else { log.Println("default matcher was set ") } } var ( defaultPort = os.Getpid()%10000 + 20000 recordMode string rootCAs *x509.CertPool ) const ( RecordingMode = "record" PlaybackMode = "playback" LiveMode = "live" IDHeader = "x-recording-id" ModeHeader = "x-recording-mode" UpstreamURIHeader = "x-recording-upstream-base-uri" recordingRandSeedVarKey = "randSeed" ) type recordedTest struct { recordingId string liveOnly bool variables map[string]interface{} recordingSeed int64 recordingRandSrc rand.Source } // testMap maps test names to metadata type testMap struct { m *sync.Map } // Load returns the named test's metadata, if it has been stored func (t *testMap) Load(name string) (recordedTest, bool) { var rt recordedTest v, ok := t.m.Load(name) if ok { rt = v.(recordedTest) } return rt, ok } // Store sets metadata for the named test func (t *testMap) Store(name string, data recordedTest) { t.m.Store(name, data) } // Remove delete metadata for the named test func (t *testMap) Remove(name string) { t.m.Delete(name) } var testSuite = testMap{&sync.Map{}} var client = http.Client{ Transport: &http.Transport{ TLSClientConfig: &tls.Config{InsecureSkipVerify: true}, }, } type RecordingOptions struct { UseHTTPS bool // ProxyPort is the port the test proxy is listening on. Defaults to the port used by [StartTestProxy]. ProxyPort int GroupForReplace string Variables map[string]interface{} TestInstance *testing.T // insecure allows this package's tests to configure the proxy to skip upstream TLS // verification so they can use a mock upstream server having a self-signed cert insecure bool } func defaultOptions() *RecordingOptions { return &RecordingOptions{ UseHTTPS: true, ProxyPort: defaultPort, } } func (r RecordingOptions) ReplaceAuthority(t *testing.T, rawReq *http.Request) *http.Request { if GetRecordMode() != LiveMode && !IsLiveOnly(t) { originalURLHost := rawReq.URL.Host // don't modify the original request cp := *rawReq cpURL := *cp.URL cp.URL = &cpURL cp.Header = rawReq.Header.Clone() cp.URL.Scheme = r.scheme() cp.URL.Host = r.host() cp.Host = r.host() cp.Header.Set(UpstreamURIHeader, fmt.Sprintf("%v://%v", r.scheme(), originalURLHost)) cp.Header.Set(ModeHeader, GetRecordMode()) cp.Header.Set(IDHeader, GetRecordingId(t)) rawReq = &cp } return rawReq } func (r RecordingOptions) host() string { port := r.ProxyPort if port == 0 { port = defaultPort } return fmt.Sprintf("localhost:%d", port) } func (r RecordingOptions) scheme() string { if r.UseHTTPS { return "https" } return "http" } func (r RecordingOptions) baseURL() string { return fmt.Sprintf("%s://%s", r.scheme(), r.host()) } func getTestId(pathToRecordings string, t *testing.T) string { return filepath.Join(pathToRecordings, "recordings", t.Name()+".json") } func getGitRoot(fromPath string) (string, error) { absPath, err := filepath.Abs(fromPath) if err != nil { return "", err } cmd := exec.Command("git", "rev-parse", "--show-toplevel") cmd.Dir = absPath root, err := cmd.CombinedOutput() if err != nil { return "", fmt.Errorf("unable to find git root for path '%s'", absPath) } // Wrap with Abs() to get os-specific path separators to support sub-path matching return filepath.Abs(strings.TrimSpace(string(root))) } // Traverse up from a recording path until an asset config file is found. // Stop searching when the root of the git repository is reached. func findAssetsConfigFile(fromPath string, untilPath string) (string, error) { absPath, err := filepath.Abs(fromPath) if err != nil { return "", err } assetConfigPath := filepath.Join(absPath, recordingAssetConfigName) if _, err := os.Stat(assetConfigPath); err == nil { return assetConfigPath, nil } else if !errors.Is(err, fs.ErrNotExist) { return "", err } if absPath == untilPath { return "", nil } parentDir := filepath.Dir(absPath) // This shouldn't be hit due to checks in getGitRoot, but it can't hurt to be defensive if parentDir == absPath || parentDir == "." { return "", nil } return findAssetsConfigFile(parentDir, untilPath) } // Returns absolute and relative paths to an asset configuration file, or an error. func getAssetsConfigLocation(pathToRecordings string) (string, string, error) { cwd, err := os.Getwd() if err != nil { return "", "", err } gitRoot, err := getGitRoot(cwd) if err != nil { return "", "", err } abs, err := findAssetsConfigFile(filepath.Join(gitRoot, pathToRecordings), gitRoot) if err != nil { return "", "", err } // Pass a path relative to the git root to test proxy so that paths // can be resolved when the repo root is mounted as a volume in a container rel := strings.Replace(abs, gitRoot, "", 1) rel = strings.TrimLeft(rel, string(os.PathSeparator)) return abs, rel, nil } func requestStart(url string, testId string, assetConfigLocation string) (*http.Response, error) { req, err := http.NewRequest("POST", url, nil) if err != nil { return nil, err } req.Header.Set("Content-Type", "application/json") reqBody := map[string]string{"x-recording-file": testId} if assetConfigLocation != "" { reqBody["x-recording-assets-file"] = assetConfigLocation } marshalled, err := json.Marshal(reqBody) if err != nil { return nil, err } req.Body = io.NopCloser(bytes.NewReader(marshalled)) req.ContentLength = int64(len(marshalled)) return client.Do(req) } func Start(t *testing.T, pathToRecordings string, options *RecordingOptions) error { if recordMode == LiveMode { return nil } if testStruct, ok := testSuite.Load(t.Name()); ok { if testStruct.liveOnly { // test should only be run live, don't want to generate recording return nil } } if options == nil { options = defaultOptions() } testId := getTestId(pathToRecordings, t) absAssetLocation, relAssetLocation, err := getAssetsConfigLocation(pathToRecordings) if err != nil { return err } url := fmt.Sprintf("%s/%s/start", options.baseURL(), recordMode) var resp *http.Response if absAssetLocation == "" { resp, err = requestStart(url, testId, "") if err != nil { return err } } else if resp, err = requestStart(url, testId, absAssetLocation); err != nil { return err } else if resp.StatusCode >= 400 { if resp, err = requestStart(url, testId, relAssetLocation); err != nil { return err } } recId := resp.Header.Get(IDHeader) if recId == "" { b, err := io.ReadAll(resp.Body) defer resp.Body.Close() if err != nil { return err } return fmt.Errorf("Recording ID was not returned by the response. Response body: %s", b) } // Unmarshal any variables returned by the proxy var m map[string]interface{} body, err := io.ReadAll(resp.Body) defer resp.Body.Close() if err != nil { return err } if len(body) > 0 { err = json.Unmarshal(body, &m) if err != nil { return err } } if val, ok := testSuite.Load(t.Name()); ok { val.recordingId = recId val.variables = m testSuite.Store(t.Name(), val) } else { testSuite.Store(t.Name(), recordedTest{ recordingId: recId, liveOnly: false, variables: m, }) } return nil } // Stop tells the test proxy to stop accepting requests for a given test func Stop(t *testing.T, options *RecordingOptions) error { if options == nil { options = defaultOptions() } if recordMode == LiveMode { return nil } if testStruct, ok := testSuite.Load(t.Name()); ok { if testStruct.liveOnly { // test should only be run live, don't want to generate recording return nil } if testStruct.recordingSeed != 0 { if options.Variables == nil { options.Variables = map[string]interface{}{} } options.Variables[recordingRandSeedVarKey] = strconv.FormatInt(testStruct.recordingSeed, 10) } } url := fmt.Sprintf("%v/%v/stop", options.baseURL(), recordMode) req, err := http.NewRequest("POST", url, nil) if err != nil { return err } if len(options.Variables) > 0 { req.Header.Set("Content-Type", "application/json") marshalled, err := json.Marshal(options.Variables) if err != nil { return err } req.Body = io.NopCloser(bytes.NewReader(marshalled)) req.ContentLength = int64(len(marshalled)) } var recTest recordedTest var ok bool if recTest, ok = testSuite.Load(t.Name()); !ok { return errors.New("Recording ID was never set. Did you call StartRecording?") } req.Header.Set(IDHeader, recTest.recordingId) testSuite.Remove(t.Name()) resp, err := client.Do(req) if err != nil { return err } if resp.StatusCode != 200 { b, err := io.ReadAll(resp.Body) defer resp.Body.Close() if err == nil { return fmt.Errorf("proxy did not stop the recording properly: %s", string(b)) } return fmt.Errorf("proxy did not stop the recording properly: %s", err.Error()) } _ = resp.Body.Close() return err } func getRandomSource(t *testing.T) rand.Source { if testStruct, ok := testSuite.Load(t.Name()); ok { if testStruct.recordingRandSrc != nil { return testStruct.recordingRandSrc } } var seed int64 var err error variables := GetVariables(t) seedString, ok := variables[recordingRandSeedVarKey] if ok { seed, err = strconv.ParseInt(seedString.(string), 10, 64) } // We did not have a random seed already stored; create a new one if !ok || err != nil || GetRecordMode() == "live" { seed = time.Now().Unix() } source := rand.NewSource(seed) if testStruct, ok := testSuite.Load(t.Name()); ok { testStruct.recordingSeed = seed testStruct.recordingRandSrc = source testSuite.Store(t.Name(), testStruct) } return source } // GenerateAlphaNumericID will generate a recorded random alpha numeric id. // When live mode or the recording has a randomSeed already set, the value will be generated from that seed, else a new random seed will be used. func GenerateAlphaNumericID(t *testing.T, prefix string, length int, lowercaseOnly bool) (string, error) { return generateAlphaNumericID(prefix, length, lowercaseOnly, getRandomSource(t)) } func generateAlphaNumericID(prefix string, length int, lowercaseOnly bool, randomSource rand.Source) (string, error) { if length <= len(prefix) { return "", errors.New("length must be greater than prefix") } sb := strings.Builder{} sb.Grow(length) sb.WriteString(prefix) i := length - len(prefix) - 1 // A src.Int63() generates 63 random bits, enough for letterIdxMax characters! for cache, remain := randomSource.Int63(), letterIdxMax; i >= 0; { if remain == 0 { cache, remain = randomSource.Int63(), letterIdxMax } if lowercaseOnly { if idx := int(cache & letterIdxMask); idx < len(alphanumericLowercaseBytes) { sb.WriteByte(alphanumericLowercaseBytes[idx]) i-- } } else { if idx := int(cache & letterIdxMask); idx < len(alphanumericBytes) { sb.WriteByte(alphanumericBytes[idx]) i-- } } cache >>= letterIdxBits remain-- } str := sb.String() return str, nil } // GetEnvVariable looks up an environment variable and if it is not found, returns the recordedValue func GetEnvVariable(varName string, recordedValue string) string { val, ok := os.LookupEnv(varName) if !ok || GetRecordMode() == PlaybackMode { return recordedValue } return val } func LiveOnly(t *testing.T) { if val, ok := testSuite.Load(t.Name()); ok { val.liveOnly = true testSuite.Store(t.Name(), val) } else { testSuite.Store(t.Name(), recordedTest{liveOnly: true}) } if GetRecordMode() == PlaybackMode { t.Skip("Live Test Only") } } // Sleep during a test for `duration` seconds. This method will only execute when // AZURE_RECORD_MODE = "record", if a test is running in playback this will be a noop. func Sleep(duration time.Duration) { if GetRecordMode() != PlaybackMode { time.Sleep(duration) } } func GetRecordingId(t *testing.T) string { if val, ok := testSuite.Load(t.Name()); ok { return val.recordingId } else { return "" } } func GetRecordMode() string { return recordMode } func findProxyCertLocation() (string, error) { fileLocation, ok := os.LookupEnv("PROXY_CERT") if ok { return fileLocation, nil } out, err := exec.Command("git", "rev-parse", "--show-toplevel").Output() if err != nil { log.Print("Could not find PROXY_CERT environment variable or toplevel of git repository, please set PROXY_CERT to location of certificate found in eng/common/testproxy/dotnet-devcert.crt") return "", err } topLevel := bytes.NewBuffer(out).String() return filepath.Join(topLevel, "eng", "common", "testproxy", "dotnet-devcert.crt"), nil } type RecordingHTTPClient struct { defaultClient *http.Client options RecordingOptions t *testing.T } func (c RecordingHTTPClient) Do(req *http.Request) (*http.Response, error) { origScheme := req.URL.Scheme origHost := req.URL.Host req = c.options.ReplaceAuthority(c.t, req) resp, err := c.defaultClient.Do(req) if err != nil { return nil, err } // if the request succeeds, restore the scheme/host with their original values. // this is imporant for things like LROs that might use the originating URL to // poll for status and/or fetch the final result. resp.Request.URL.Scheme = origScheme resp.Request.URL.Host = origHost // if the response is a recording mismatch error from the proxy, return // its message as a simple error that prints legibly in test output if er := resp.Header.Get("x-request-mismatch-error"); er != "" { if msg, e := base64.StdEncoding.DecodeString(er); e == nil { err = errors.New(string(msg)) } } return resp, err } // NewRecordingHTTPClient returns a type that implements `azcore.Transporter`. This will automatically route tests on the `Do` call. func NewRecordingHTTPClient(t *testing.T, options *RecordingOptions) (*RecordingHTTPClient, error) { if options == nil { options = defaultOptions() } c, err := GetHTTPClient(t) if err != nil { return nil, err } return &RecordingHTTPClient{ defaultClient: c, options: *options, t: t, }, nil } func GetHTTPClient(t *testing.T) (*http.Client, error) { transport := http.DefaultTransport.(*http.Transport).Clone() transport.TLSClientConfig.RootCAs = rootCAs transport.TLSClientConfig.MinVersion = tls.VersionTLS12 transport.TLSClientConfig.InsecureSkipVerify = true defaultHttpClient := &http.Client{ Transport: transport, } return defaultHttpClient, nil } func IsLiveOnly(t *testing.T) bool { if s, ok := testSuite.Load(t.Name()); ok { return s.liveOnly } return false } // GetVariables returns access to the variables stored by the test proxy for a specific test func GetVariables(t *testing.T) map[string]interface{} { if s, ok := testSuite.Load(t.Name()); ok { return s.variables } return nil }