daisy_test_runner/main.go (590 lines of code) (raw):

// Copyright 2018 Google Inc. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // daisy_test_runner is a tool for testing using Daisy workflows. package main import ( "bytes" "context" "encoding/json" "encoding/xml" "errors" "flag" "fmt" "io/ioutil" "log" "math/rand" "os" "os/signal" "path/filepath" "regexp" "strings" "sync" "text/template" "time" daisy "github.com/GoogleCloudPlatform/compute-daisy" daisyCompute "github.com/GoogleCloudPlatform/compute-daisy/compute" "github.com/google/uuid" "google.golang.org/api/compute/v1" ) const ( defaultParallelCount = 5 timeFormat = time.RFC3339 ) var ( oauth = flag.String("oauth", "", "path to oauth json file") projects = flag.String("projects", "", "comma separated list of projects that can be used for tests, overrides setting in template") zone = flag.String("zone", "", "zone to use for tests, overrides setting in template") print = flag.Bool("print", false, "print out the parsed test cases for debugging") printTemplate = flag.Bool("print_template", false, "print out the parsed test template for debugging") validate = flag.Bool("validate", false, "validate all the test cases and exit") ce = flag.String("compute_endpoint_override", "", "API endpoint to override default, will override ComputeEndpoint in template") filter = flag.String("filter", "", "test name filter") outPath = flag.String("out_path", "junit.xml", "junit xml path") parallelCount = flag.Int("parallel_count", 0, "TestParallelCount") funcMap = map[string]interface{}{ "randItem": randItem, "mkSlice": mkSlice, "mkMap": mkMap, "split": strings.Split, "add": func(i, a int) int { return i + a }, } testTemplate = template.New("testTemplate").Option("missingkey=zero").Funcs(funcMap) ) func randItem(args []string) string { rand.Seed(time.Now().UnixNano()) return args[rand.Intn(len(args))] } func mkSlice(args ...string) []string { return args } func mkMap(args ...string) map[string]string { m := make(map[string]string) for _, arg := range args { split := strings.Split(arg, ":") if len(split) != 2 { continue } m[split[0]] = split[1] } return m } // A TestSuite describes the tests to run. type TestSuite struct { // Name for this set of tests. Name string // Project pool to use. Projects []string // Default zone to use. Zone string // The test cases to run. Tests map[string]*TestCase // How many tests to run in parallel. TestParallelCount int OAuthPath string ComputeEndpoint string } // A TestCase is a single test to run. type TestCase struct { // Path to the daisy workflow to use. // Each test workflow should manage its own resource creation and cleanup. Path string w *daisy.Workflow id string logger *logger // Vars to pass to the daisy workflow. Vars map[string]string // Default timeout is 2 hours. // Must be parsable by https://golang.org/pkg/time/#ParseDuration. TestTimeout string timeout time.Duration // Optional settings that will override those set in the workflow or TestTemplate. Zone string OAuthPath string ComputeEndpoint string // If set this test will be the only test allowed to run in the project. // This is required for any test that changes project level settings that may // impact other concurrent test runs. ProjectLock bool CustomProjectLock string } type logger struct { buf bytes.Buffer mx sync.Mutex } func (l *logger) AppendSerialPortLogs(w *daisy.Workflow, instance string, logs string) { // no-op } func (l *logger) WriteSerialPortLogsToCloudLogging(w *daisy.Workflow, instance string) { // no-op } func (l *logger) WriteLogEntry(e *daisy.LogEntry) { l.mx.Lock() defer l.mx.Unlock() l.buf.WriteString(e.String()) } func (l *logger) ReadSerialPortLogs() []string { return nil } func (l *logger) Flush() { return } func createTestCase(ctx context.Context, testLogger *logger, path, project, zone, oauthPath, ce string, varMap map[string]string) (*daisy.Workflow, error) { w, err := daisy.NewFromFile(path) if err != nil { return nil, err } for k, v := range varMap { w.AddVar(k, v) } if oauthPath != "" { w.OAuthPath = oauthPath } if ce != "" { w.ComputeEndpoint = ce } if err := w.PopulateClients(ctx); err != nil { return nil, err } w.Project = project w.Zone = zone w.DisableGCSLogging() w.DisableCloudLogging() w.DisableStdoutLogging() w.Logger = testLogger if len(w.Steps) == 0 { return nil, nil } return w, nil } func createTestSuite(ctx context.Context, path string, varMap map[string]string, regex *regexp.Regexp) (*TestSuite, error) { var t TestSuite b, err := ioutil.ReadFile(path) if err != nil { return nil, fmt.Errorf("%s: %v", path, err) } templ, err := testTemplate.Parse(string(b)) if err != nil { return nil, fmt.Errorf("%s: %v", path, err) } var buf bytes.Buffer if err := templ.Execute(&buf, varMap); err != nil { return nil, fmt.Errorf("%s: %v", path, err) } if *printTemplate { fmt.Println(buf.String()) return nil, nil } if err := json.Unmarshal(buf.Bytes(), &t); err != nil { return nil, daisy.JSONError(path, buf.Bytes(), err) } if *projects != "" { t.Projects = strings.Split(*projects, ",") } if len(t.Projects) == 0 { return nil, errors.New("no projects provided") } if *zone != "" { t.Zone = *zone } if *oauth != "" { t.OAuthPath = *oauth } if *ce != "" { t.ComputeEndpoint = *ce } if *parallelCount != 0 { t.TestParallelCount = *parallelCount } if t.TestParallelCount == 0 { t.TestParallelCount = defaultParallelCount } fmt.Printf("[TestRunner] Creating test cases for test suite %q\n", t.Name) for name, test := range t.Tests { test.id = uuid.New().String() if test.TestTimeout == "" { test.timeout = defaultTimeout } else { d, err := time.ParseDuration(test.TestTimeout) if err != nil { test.timeout = defaultTimeout } else { test.timeout = d } } if regex != nil && !regex.MatchString(name) { continue } fmt.Printf(" - Creating test case for %q\n", name) wfPath := filepath.Join(filepath.Dir(path), test.Path) for k, v := range test.Vars { varMap[k] = v } zone := t.Zone if test.Zone != "" { zone = test.Zone } oauthPath := t.OAuthPath if test.OAuthPath != "" { oauthPath = test.OAuthPath } computeEndpoint := t.ComputeEndpoint if test.ComputeEndpoint != "" { computeEndpoint = test.ComputeEndpoint } rand.Seed(time.Now().UnixNano()) test.logger = &logger{} w, err := createTestCase(ctx, test.logger, wfPath, t.Projects[rand.Intn(len(t.Projects))], zone, oauthPath, computeEndpoint, varMap) if err != nil { return nil, err } test.w = w } return &t, nil } const ( flgDefValue = "flag generated for workflow variable" varFlagPrefix = "var:" ) func addFlags(args []string) { for _, arg := range args { if len(arg) <= 1 || arg[0] != '-' { continue } name := arg[1:] if name[0] == '-' { name = name[1:] } if !strings.HasPrefix(name, varFlagPrefix) { continue } name = strings.SplitN(name, "=", 2)[0] if flag.Lookup(name) != nil { continue } flag.String(name, "", flgDefValue) } } func checkError(errors chan error) { select { case err := <-errors: fmt.Fprintln(os.Stderr, "\n[TestRunner] Errors in one or more test cases:") fmt.Fprintln(os.Stderr, "\n - ", err) for { select { case err := <-errors: fmt.Fprintln(os.Stderr, "\n - ", err) continue default: fmt.Fprintln(os.Stderr, "\n[TestRunner] Exiting with exit code 1") os.Exit(1) } } default: return } } type junitTestSuite struct { mx sync.Mutex XMLName xml.Name `xml:"testsuite"` Name string `xml:"name,attr"` Tests int `xml:"tests,attr"` Failures int `xml:"failures,attr"` Errors int `xml:"errors,attr"` Disabled int `xml:"disabled,attr"` Skipped int `xml:"skipped,attr"` Time float64 `xml:"time,attr"` TestCase []*junitTestCase `xml:"testcase"` } type junitTestCase struct { Classname string `xml:"classname,attr"` ID string `xml:"id,attr"` Name string `xml:"name,attr"` Time float64 `xml:"time,attr"` Skipped *junitSkipped `xml:"skipped,omitempty"` Failure *junitFailure `xml:"failure,omitempty"` SystemOut string `xml:"system-out,omitempty"` } type junitSkipped struct { Message string `xml:"message,attr"` } type junitFailure struct { FailMessage string `xml:",chardata"` FailType string `xml:"type,attr"` } type test struct { name string testCase *TestCase } func getCommonInstanceMetadata(client daisyCompute.Client, project string) (*compute.Metadata, error) { proj, err := client.GetProject(project) if err != nil { return nil, fmt.Errorf("error getting project: %v", err) } return proj.CommonInstanceMetadata, nil } func delItem(items []*compute.MetadataItems, i int) []*compute.MetadataItems { // Delete the element. // https://github.com/golang/go/wiki/SliceTricks copy(items[i:], items[i+1:]) items[len(items)-1] = nil return items[:len(items)-1] } func isExpired(val string) bool { t, err := time.Parse(timeFormat, val) if err != nil { return false } return time.Now().After(t) } const ( writeLock = "TestWriteLock-" readLock = "TestReadLock-" defaultTimeout = 2 * time.Hour ) func waitLock(client daisyCompute.Client, project string, prefix ...string) (*compute.Metadata, error) { var md *compute.Metadata var err error Loop: for { md, err = getCommonInstanceMetadata(client, project) if err != nil { return nil, err } for i, mdi := range md.Items { if mdi != nil { for _, p := range prefix { if strings.HasPrefix(mdi.Key, p) { if isExpired(*mdi.Value) { md.Items = delItem(md.Items, i) } else { r := rand.Intn(10) + 5 time.Sleep(time.Duration(r) * time.Second) continue Loop } } } } } return md, nil } } func projectReadLock(client daisyCompute.Client, project, key string, timeout time.Duration) (string, error) { md, err := waitLock(client, project, writeLock) if err != nil { return "", err } lock := readLock + key val := time.Now().Add(timeout).Format(timeFormat) md.Items = append(md.Items, &compute.MetadataItems{Key: lock, Value: &val}) if err := client.SetCommonInstanceMetadata(project, md); err != nil { return "", err } return lock, nil } func customProjectWriteLock(client daisyCompute.Client, project, custom, key string, timeout time.Duration) (string, error) { customLock := readLock + custom md, err := waitLock(client, project, writeLock, customLock) if err != nil { return "", err } lock := customLock + key val := time.Now().Add(timeout).Format(timeFormat) md.Items = append(md.Items, &compute.MetadataItems{Key: lock, Value: &val}) if err := client.SetCommonInstanceMetadata(project, md); err != nil { return "", err } return lock, nil } func projectWriteLock(client daisyCompute.Client, project, key string, timeout time.Duration) (string, error) { md, err := waitLock(client, project, writeLock) if err != nil { return "", err } // This means the project has no current write locks, set the write lock // now and then wait till all current read locks are gone. lock := writeLock + key val := time.Now().Add(timeout).Format(timeFormat) md.Items = append(md.Items, &compute.MetadataItems{Key: lock, Value: &val}) if err := client.SetCommonInstanceMetadata(project, md); err != nil { return "", err } if _, err := waitLock(client, project, readLock); err != nil { // Attempt to unlock. projectUnlock(client, project, lock) return "", err } return lock, nil } func projectUnlock(client daisyCompute.Client, project, lock string) error { md, err := getCommonInstanceMetadata(client, project) if err != nil { return err } for i, mdi := range md.Items { if mdi != nil && lock == mdi.Key { md.Items = delItem(md.Items, i) } } return client.SetCommonInstanceMetadata(project, md) } var allowedChars = regexp.MustCompile("[^-_a-zA-Z0-9]+") func runTestCase(ctx context.Context, test *test, tc *junitTestCase, errors chan error, retries int) { if err := test.testCase.w.PopulateClients(ctx); err != nil { errors <- fmt.Errorf("%s: %v", tc.Name, err) tc.Failure = &junitFailure{FailMessage: err.Error(), FailType: "Error"} return } c := make(chan os.Signal, 1) signal.Notify(c, os.Interrupt) go func() { select { case <-c: fmt.Printf("\nCtrl-C caught, sending cancel signal to %q...\n", test.name) test.testCase.w.CancelWorkflow() err := fmt.Errorf("test case %q was canceled", test.name) errors <- err tc.Failure = &junitFailure{FailMessage: err.Error(), FailType: "Canceled"} case <-test.testCase.w.Cancel: } }() project := test.testCase.w.Project client := test.testCase.w.ComputeClient key := test.testCase.w.ID() var lock string var err error if test.testCase.CustomProjectLock != "" { for i := 0; i < retries; i++ { lock, err = customProjectWriteLock(client, project, allowedChars.ReplaceAllString(test.testCase.CustomProjectLock, "_"), key, test.testCase.timeout) if err == nil { break } } if err != nil { errors <- err return } } else if test.testCase.ProjectLock { for i := 0; i < retries; i++ { lock, err = projectWriteLock(client, project, key, test.testCase.timeout) if err == nil { break } } if err != nil { errors <- err return } } else { for i := 0; i < retries; i++ { lock, err = projectReadLock(client, project, key, test.testCase.timeout) if err == nil { break } } if err != nil { errors <- err return } } defer func() { for i := 0; i < retries; i++ { err := projectUnlock(client, project, lock) if err == nil { break } } if err != nil { fmt.Printf("[TestRunner] Test %q: Error unlocking project: %v\n", test.name, err) } }() select { case <-test.testCase.w.Cancel: return default: } start := time.Now() fmt.Printf("[TestRunner] Running test case %q\n", tc.Name) if err := test.testCase.w.Run(ctx); err != nil { errors <- fmt.Errorf("%s: %v", tc.Name, err) tc.Failure = &junitFailure{FailMessage: err.Error(), FailType: "Failure"} } tc.Time = time.Since(start).Seconds() tc.SystemOut = test.testCase.logger.buf.String() fmt.Printf("[TestRunner] Test case %q finished\n", tc.Name) } func main() { addFlags(os.Args[1:]) flag.Parse() varMap := map[string]string{} flag.Visit(func(flg *flag.Flag) { if strings.HasPrefix(flg.Name, varFlagPrefix) { varMap[strings.TrimPrefix(flg.Name, varFlagPrefix)] = flg.Value.String() } }) if len(flag.Args()) == 0 { fmt.Println("Not enough args, first arg needs to be the path to a test template.") os.Exit(1) } var regex *regexp.Regexp if *filter != "" { var err error regex, err = regexp.Compile(*filter) if err != nil { fmt.Println("-filter flag not valid:", err) os.Exit(1) } } ctx := context.Background() ts, err := createTestSuite(ctx, flag.Arg(0), varMap, regex) if err != nil { log.Fatalln("test case creation error:", err) } if ts == nil { return } errors := make(chan error, len(ts.Tests)) // Retry failed locks 2x as many tests in the test case. retries := len(ts.Tests) * 2 if len(ts.Tests) == 0 { fmt.Println("[TestRunner] Nothing to do") return } if *print { for n, t := range ts.Tests { if t.w == nil { continue } fmt.Printf("[TestRunner] Printing test case %q\n", n) t.w.Print(ctx) } checkError(errors) return } if *validate { for n, t := range ts.Tests { if t.w == nil { continue } fmt.Printf("[TestRunner] Validating test case %q\n", n) if err := t.w.Validate(ctx); err != nil { errors <- fmt.Errorf("Error validating test case %s: %v", n, err) } } checkError(errors) return } if err := os.MkdirAll(filepath.Dir(*outPath), 0770); err != nil { log.Fatal(err) } junit := &junitTestSuite{Name: ts.Name, Tests: len(ts.Tests)} tests := make(chan *test, len(ts.Tests)) var wg sync.WaitGroup for i := 0; i < ts.TestParallelCount; i++ { wg.Add(1) go func() { defer wg.Done() for test := range tests { tc := &junitTestCase{Classname: ts.Name, ID: test.testCase.id, Name: test.name} junit.mx.Lock() junit.TestCase = append(junit.TestCase, tc) junit.mx.Unlock() if test.testCase.w == nil { junit.mx.Lock() junit.Skipped++ junit.mx.Unlock() tc.Skipped = &junitSkipped{Message: fmt.Sprintf("Test does not match filter: %q", regex.String())} continue } runTestCase(ctx, test, tc, errors, retries) } }() } start := time.Now() for n, t := range ts.Tests { tests <- &test{name: n, testCase: t} } close(tests) wg.Wait() fmt.Printf("[TestRunner] Creating junit xml file: %q\n", *outPath) junit.Time = time.Since(start).Seconds() d, err := xml.MarshalIndent(junit, " ", " ") if err != nil { log.Fatal(err) } if err := ioutil.WriteFile(*outPath, d, 0644); err != nil { log.Fatal(err) } checkError(errors) fmt.Println("[TestRunner] All test cases completed successfully.") }