cli/main.go (267 lines of code) (raw):
// Copyright 2017 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Daisy is a GCE workflow tool.
package main
import (
"context"
"encoding/json"
"flag"
"fmt"
"io/ioutil"
"log"
"os"
"os/signal"
"strings"
"sync"
"time"
"cloud.google.com/go/compute/metadata"
daisy "github.com/GoogleCloudPlatform/compute-daisy"
)
// Command-line flags. main reads them after flag.Parse; the override and
// logging flags are forwarded to parseWorkflow for every workflow path given
// on the command line.
//
// NOTE: the print variable shadows the builtin print function throughout this
// package.
var (
	oauth              = flag.String("oauth", "", "path to oauth json file, overrides what is set in workflow")
	project            = flag.String("project", "", "project to run in, overrides what is set in workflow")
	gcsPath            = flag.String("gcs_path", "", "GCS bucket to use, overrides what is set in workflow")
	zone               = flag.String("zone", "", "zone to run in, overrides what is set in workflow")
	variables          = flag.String("variables", "", "comma separated list of variables, in the form 'key=value'")
	print              = flag.Bool("print", false, "print out the parsed workflow for debugging")
	printPerf          = flag.Bool("print_perf", false, "print out the performance profile")
	validate           = flag.Bool("validate", false, "validate the workflow and exit")
	format             = flag.Bool("format_workflow", false, "format the workflow file(s) and exit")
	defaultTimeout     = flag.String("default_timeout", "", "sets the default timeout for the workflow")
	ce                 = flag.String("compute_endpoint_override", "", "API endpoint to override default")
	gcsLogsDisabled    = flag.Bool("disable_gcs_logging", false, "do not stream logs to GCS")
	cloudLogsDisabled  = flag.Bool("disable_cloud_logging", false, "do not stream logs to Cloud Logging")
	stdoutLogsDisabled = flag.Bool("disable_stdout_logging", false, "do not display individual workflow logs on stdout")
	// Projects with numerous images can improve performance by getting images directly rather than
	// listing and caching them, since listing becomes slower than direct retrieval in such cases.
	skipCachingImages = flag.Bool("skip_caching_images", false, "do not cache images from workflow project on startup")
)
const (
	// flgDefValue is the usage text attached to the flags that addFlags
	// registers on the fly.
	flgDefValue = "flag generated for workflow variable"
	// varFlagPrefix is the flag-name prefix that marks a command-line flag as a
	// workflow variable; populateVars strips it to obtain the variable name.
	varFlagPrefix = "var:"
)
// populateVars builds the workflow variable map from two sources: the comma
// separated key=value list in input, and every flag that was explicitly set on
// the command line whose name starts with varFlagPrefix (the prefix is
// stripped to form the key). Flag values are applied last, so they replace
// entries from input that share a key.
func populateVars(input string) map[string]string {
	vars := make(map[string]string)

	// An empty input splits into a single empty entry, which has no '=' and
	// is therefore skipped like any other malformed entry.
	for _, pair := range strings.Split(input, ",") {
		// Split on the first '=' only, so values may themselves contain '='.
		if eq := strings.IndexByte(pair, '='); eq >= 0 {
			vars[pair[:eq]] = pair[eq+1:]
		}
	}

	flag.Visit(func(f *flag.Flag) {
		if !strings.HasPrefix(f.Name, varFlagPrefix) {
			return
		}
		vars[f.Name[len(varFlagPrefix):]] = f.Value.String()
	})

	return vars
}
// parseWorkflow loads the workflow file at path and applies the command-line
// overrides to it.
//
// Every key in varMap must name a variable already present in the workflow's
// Vars; the first key that does not produces an error. project and zone
// replace the workflow's values when non-empty. When the argument and the
// workflow value are both empty and the process runs on GCE, the value is
// fetched from the metadata server instead. gcsPath, oauth, dTimeout and
// cEndpoint replace the corresponding workflow fields only when non-empty.
// Each boolean argument, when true, calls the matching Disable*/Skip* method
// on the workflow.
//
// ctx is accepted for signature stability but is not used here.
func parseWorkflow(ctx context.Context, path string, varMap map[string]string, project, zone, gcsPath, oauth, dTimeout, cEndpoint string, disableGCSLogs, disableCloudLogs, disableStdoutLogs, skipCachingImages bool) (*daisy.Workflow, error) {
	w, err := daisy.NewFromFile(path)
	if err != nil {
		return nil, err
	}

	// Only variables the workflow declares may be set from the command line.
	for k, v := range varMap {
		if _, ok := w.Vars[k]; !ok {
			return nil, fmt.Errorf("unknown workflow Var %q passed to Workflow %q", k, w.Name)
		}
		w.AddVar(k, v)
	}

	if project != "" {
		w.Project = project
	} else if w.Project == "" && metadata.OnGCE() {
		w.Project, err = metadata.ProjectID()
		if err != nil {
			return nil, fmt.Errorf("Failed to get GCE project id from metadata: %w", err)
		}
	}

	if zone != "" {
		w.Zone = zone
	} else if w.Zone == "" && metadata.OnGCE() {
		w.Zone, err = metadata.Zone()
		if err != nil {
			return nil, fmt.Errorf("Failed to get GCE zone from metadata: %w", err)
		}
	}

	if gcsPath != "" {
		w.GCSPath = gcsPath
	}
	if oauth != "" {
		w.OAuthPath = oauth
	}
	if dTimeout != "" {
		w.DefaultTimeout = dTimeout
	}
	if cEndpoint != "" {
		w.ComputeEndpoint = cEndpoint
	}

	if disableGCSLogs {
		w.DisableGCSLogging()
	}
	if disableCloudLogs {
		w.DisableCloudLogging()
	}
	if disableStdoutLogs {
		w.DisableStdoutLogging()
	}
	if skipCachingImages {
		w.SkipCachingImages()
	}

	return w, nil
}
// addFlags scans args for flags of the form -var:NAME or --var:NAME
// (optionally followed by =VALUE) and registers each one as a string flag on
// the default flag set, so that a later flag.Parse accepts it. Names that are
// already registered are left alone, which also makes repeated occurrences of
// the same flag safe.
func addFlags(args []string) {
	for _, raw := range args {
		// A candidate needs a leading dash plus at least one more byte.
		if len(raw) < 2 || raw[0] != '-' {
			continue
		}

		// Accept both -name and --name: drop at most one further dash.
		name := strings.TrimPrefix(raw[1:], "-")
		if !strings.HasPrefix(name, varFlagPrefix) {
			continue
		}

		// Keep only the flag name when the argument is in name=value form.
		if eq := strings.Index(name, "="); eq >= 0 {
			name = name[:eq]
		}

		if flag.Lookup(name) == nil {
			flag.String(name, "", flgDefValue)
		}
	}
}
// fmtWorkflow rewrites the workflow JSON file at path in place: the file is
// read, decoded into a daisy.Workflow, re-encoded with json.MarshalIndent and
// written back over the original contents.
//
// It returns the first error encountered. Decode errors are passed through
// daisy.JSONError together with the path and raw data. The file is closed on
// every return path; a Close error is reported only when nothing else failed.
func fmtWorkflow(path string) (err error) {
	f, err := os.OpenFile(path, os.O_RDWR, 0)
	if err != nil {
		return err
	}
	defer func() {
		// Do not let a Close error mask an earlier, more specific failure.
		if cerr := f.Close(); err == nil {
			err = cerr
		}
	}()

	data, err := ioutil.ReadAll(f)
	if err != nil {
		return err
	}

	var w *daisy.Workflow
	if err := json.Unmarshal(data, &w); err != nil {
		return daisy.JSONError(path, data, err)
	}

	newData, err := json.MarshalIndent(w, "", " ")
	if err != nil {
		return err
	}

	// Truncate first so a shorter result leaves no trailing bytes of the old
	// contents behind, then write from the start of the file.
	if err := f.Truncate(0); err != nil {
		return err
	}
	_, err = f.WriteAt(newData, 0)
	return err
}
// printPerfProfile prints one line per step time record of workflow, showing
// the step name and its duration, followed by the total time spanned from the
// earliest recorded start to the latest recorded end. Nothing is printed when
// the workflow has no time records.
func printPerfProfile(workflow *daisy.Workflow) {
	records := workflow.GetStepTimeRecords()
	if len(records) == 0 {
		return
	}

	// Seed the minimum with "now" and the maximum with the zero time so the
	// records themselves narrow the range.
	earliest, latest := time.Now(), time.Time{}

	fmt.Println("\nPerf Profile:")
	for _, rec := range records {
		if rec.StartTime.Before(earliest) {
			earliest = rec.StartTime
		}
		if rec.EndTime.After(latest) {
			latest = rec.EndTime
		}
		elapsed := rec.EndTime.Sub(rec.StartTime)
		fmt.Printf("- %v: %v\n", rec.Name, formatDuration(elapsed))
	}

	total := latest.Sub(earliest)
	fmt.Printf("Total time: %v\n\n", formatDuration(total))
}
// formatDuration renders d, truncated to whole seconds, as
// "[hh:mm:ss] HH:MM:SS". Minutes and seconds are always two digits; hours are
// zero-padded to two digits and grow beyond that as needed.
func formatDuration(d time.Duration) string {
	s := int64(d / time.Second)
	return fmt.Sprintf("[hh:mm:ss] %02d:%02d:%02d", s/3600, s/60%60, s%60)
}
// main is the Daisy CLI entry point. The positional arguments are paths to
// workflow files.
//
// With -format_workflow each file is reformatted in place and the program
// returns. Otherwise every file is parsed with the command-line overrides
// applied, and each workflow is then printed (-print), validated (-validate)
// or run concurrently. Run errors and Ctrl-C cancellations are collected and
// reported on stderr, in which case the process exits with status 1.
func main() {
	// Register -var:* flags before parsing so flag.Parse accepts them.
	addFlags(os.Args[1:])
	flag.Parse()

	if len(flag.Args()) == 0 {
		log.Fatal("Not enough args, first arg needs to be the path to a workflow.")
	}

	if *format {
		for _, path := range flag.Args() {
			fmt.Printf("[Daisy] Formating workflow file %q\n", path)
			if err := fmtWorkflow(path); err != nil {
				// Terminate the line so the next status message starts cleanly.
				fmt.Println(err)
			}
		}
		return
	}

	ctx := context.Background()

	var ws []*daisy.Workflow
	varMap := populateVars(*variables)
	for _, path := range flag.Args() {
		w, err := parseWorkflow(ctx, path, varMap, *project, *zone, *gcsPath, *oauth, *defaultTimeout, *ce, *gcsLogsDisabled, *cloudLogsDisabled, *stdoutLogsDisabled, *skipCachingImages)
		if err != nil {
			log.Fatalf("error parsing workflow %q: %v", path, err)
		}
		ws = append(ws, w)
	}

	// Each workflow can report up to two errors: one from its Ctrl-C handler
	// and one from Run. The buffer must hold both, otherwise a run goroutine
	// could block on its send before reaching wg.Done and wg.Wait below would
	// never return.
	errCh := make(chan error, 2*len(ws))
	var wg sync.WaitGroup
	for _, w := range ws {
		// Every workflow gets its own interrupt channel so a single Ctrl-C
		// cancels all of them.
		c := make(chan os.Signal, 1)
		signal.Notify(c, os.Interrupt)
		go func(w *daisy.Workflow) {
			select {
			case <-c:
				fmt.Printf("\nCtrl-C caught, sending cancel signal to %q...\n", w.Name)
				w.CancelWorkflow()
				errCh <- fmt.Errorf("workflow %q was canceled", w.Name)
			case <-w.Cancel:
			}
		}(w)

		if *print {
			fmt.Printf("[Daisy] Printing workflow %q\n", w.Name)
			w.Print(ctx)
			continue
		}

		if *validate {
			fmt.Printf("[Daisy] Validating workflow %q\n", w.Name)
			if err := w.Validate(ctx); err != nil {
				fmt.Fprintf(os.Stderr, "[Daisy] Error validating workflow %q: %v\n", w.Name, err)
			}
			continue
		}

		wg.Add(1)
		go func(w *daisy.Workflow) {
			defer wg.Done()
			if *printPerf {
				defer printPerfProfile(w)
			}
			fmt.Printf("[Daisy] Running workflow %q (id=%s)\n", w.Name, w.ID())
			if err := w.Run(ctx); err != nil {
				errCh <- fmt.Errorf("%s: %v", w.Name, err)
				return
			}
			fmt.Printf("[Daisy] Workflow %q finished\n", w.Name)
		}(w)
	}
	wg.Wait()

	// Non-blocking drain: report every queued error, then exit non-zero.
	select {
	case err := <-errCh:
		fmt.Fprintln(os.Stderr, "\n[Daisy] Errors in one or more workflows:")
		fmt.Fprintln(os.Stderr, " ", err)
		for {
			select {
			case err := <-errCh:
				fmt.Fprintln(os.Stderr, " ", err)
				continue
			default:
				os.Exit(1)
			}
		}
	default:
		if !*print && !*validate {
			fmt.Println("[Daisy] All workflows completed successfully.")
		}
	}
}