google_metadata_script_runner/main.go (376 lines of code) (raw):

// Copyright 2017 Google LLC // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // https://www.apache.org/licenses/LICENSE-2.0 // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. // GCEMetadataScripts handles the running of metadata scripts on Google Compute // Engine instances. package main // TODO: compare log outputs in this utility to linux. import ( "bufio" "context" "encoding/json" "fmt" "io" "net" "net/http" "net/url" "os" "os/exec" "path" "path/filepath" "regexp" "runtime" "strings" "time" "cloud.google.com/go/storage" "github.com/GoogleCloudPlatform/guest-agent/google_guest_agent/cfg" "github.com/GoogleCloudPlatform/guest-agent/metadata" "github.com/GoogleCloudPlatform/guest-agent/retry" "github.com/GoogleCloudPlatform/guest-agent/utils" "github.com/GoogleCloudPlatform/guest-logging-go/logger" ) const ( storageURL = "storage.googleapis.com" bucket = "([a-z0-9][-_.a-z0-9]*)" object = "(.+)" defaultTimeout = 20 * time.Second ) var ( programName = path.Base(os.Args[0]) powerShellArgs = []string{"-NoProfile", "-NoLogo", "-ExecutionPolicy", "Unrestricted", "-File"} errUsage = fmt.Errorf("no valid arguments specified. Specify one of \"startup\", \"shutdown\" or \"specialize\"") // Many of the Google Storage URLs are supported below. // It is preferred that customers specify their object using // its gs://<bucket>/<object> URL. gsRegex = regexp.MustCompile(fmt.Sprintf(`^gs://%s/%s$`, bucket, object)) // Check for the Google Storage URLs: // http://<bucket>.storage.googleapis.com/<object> // https://<bucket>.storage.googleapis.com/<object> gsHTTPRegex1 = regexp.MustCompile(fmt.Sprintf(`^http[s]?://%s\.storage\.googleapis\.com/%s$`, bucket, object)) // http://storage.cloud.google.com/<bucket>/<object> // https://storage.cloud.google.com/<bucket>/<object> gsHTTPRegex2 = regexp.MustCompile(fmt.Sprintf(`^http[s]?://storage\.cloud\.google\.com/%s/%s$`, bucket, object)) // Check for the other possible Google Storage URLs: // http://storage.googleapis.com/<bucket>/<object> // https://storage.googleapis.com/<bucket>/<object> // // The following are deprecated but also checked: // http://commondatastorage.googleapis.com/<bucket>/<object> // https://commondatastorage.googleapis.com/<bucket>/<object> gsHTTPRegex3 = regexp.MustCompile(fmt.Sprintf(`^http[s]?://(?:commondata)?storage\.googleapis\.com/%s/%s$`, bucket, object)) // testStorageClient is used to override GCS client in unit tests. testStorageClient *storage.Client client metadata.MDSClientInterface version string // defaultRetryPolicy is default policy to retry up to 3 times, only wait 1 second between retries. defaultRetryPolicy = retry.Policy{MaxAttempts: 3, BackoffFactor: 1, Jitter: time.Second} ) func init() { client = metadata.New() } func newStorageClient(ctx context.Context) (*storage.Client, error) { if testStorageClient != nil { return testStorageClient, nil } return storage.NewClient(ctx) } func downloadGSURL(ctx context.Context, bucket, object string, file *os.File) error { client, err := newStorageClient(ctx) if err != nil { return fmt.Errorf("failed to create storage client: %v", err) } defer client.Close() r, err := retry.RunWithResponse(ctx, defaultRetryPolicy, func() (*storage.Reader, error) { r, err := client.Bucket(bucket).Object(object).NewReader(ctx) return r, err }) if err != nil { return err } defer r.Close() _, err = io.Copy(file, r) return err } func downloadURL(ctx context.Context, url string, file *os.File) error { res, err := retry.RunWithResponse(ctx, defaultRetryPolicy, func() (*http.Response, error) { res, err := http.Get(url) if err != nil { return res, err } if res.StatusCode != http.StatusOK { return nil, fmt.Errorf("GET %q, bad status: %s", url, res.Status) } return res, nil }) if err != nil { return err } defer res.Body.Close() _, err = io.Copy(file, res.Body) return err } func downloadScript(ctx context.Context, path string, file *os.File) error { // Startup scripts may run before DNS is running on some systems, // particularly once a system is promoted to a domain controller. // Try to lookup storage.googleapis.com and sleep for up to 100s if // we get an error. policy := retry.Policy{MaxAttempts: 20, BackoffFactor: 1, Jitter: time.Second * 5} err := retry.Run(ctx, policy, func() error { _, err := net.LookupHost(storageURL) return err }) if err != nil { return fmt.Errorf("%q lookup failed, err: %+v", storageURL, err) } bucket, object := parseGCS(path) if bucket != "" && object != "" { err = downloadGSURL(ctx, bucket, object, file) if err == nil { logger.Debugf("Succesfull download using GSURL, bucket: %s, object: %s, file: %+v", bucket, object, file) return nil } logger.Infof("Failed to download object [%s] from GCS bucket [%s], err: %+v", object, bucket, err) logger.Infof("Trying unauthenticated download") path = fmt.Sprintf("https://%s/%s/%s", storageURL, bucket, object) } // Fall back to an HTTP GET of the URL. return downloadURL(ctx, path, file) } func parseGCS(path string) (string, string) { for _, re := range []*regexp.Regexp{gsRegex, gsHTTPRegex1, gsHTTPRegex2, gsHTTPRegex3} { match := re.FindStringSubmatch(path) if len(match) == 3 { return match[1], match[2] } } return "", "" } func getMetadataKey(ctx context.Context, key string) (string, error) { md, err := getMetadata(ctx, key, false) if err != nil { return "", err } return string(md), nil } func getMetadataAttributes(ctx context.Context, key string) (map[string]string, error) { md, err := getMetadata(ctx, key, true) if err != nil { return nil, err } var att map[string]string return att, json.Unmarshal(md, &att) } func getMetadata(ctx context.Context, key string, recurse bool) ([]byte, error) { var resp string var err error if recurse { resp, err = client.GetKeyRecursive(ctx, key) } else { resp, err = client.GetKey(ctx, key, nil) } if err != nil { return nil, fmt.Errorf("unable to get %q from MDS, with recursive flag set to %t: %w", key, recurse, err) } return []byte(resp), nil } func normalizeFilePathForWindows(filePath string, metadataKey string, gcsScriptURL *url.URL) string { // If either the metadataKey ends in one of these extensions OR if this is a url startup script and if the // url path ends in one of these extensions, append the extension to the filePath name so that Windows can recognize it. for _, ext := range []string{"bat", "cmd", "ps1", "exe"} { if strings.HasSuffix(metadataKey, "-"+ext) || (gcsScriptURL != nil && strings.HasSuffix(gcsScriptURL.Path, "."+ext)) { filePath = fmt.Sprintf("%s.%s", filePath, ext) break } } return filePath } func writeScriptToFile(ctx context.Context, value string, filePath string, gcsScriptURL *url.URL) error { // Create or download files. if gcsScriptURL != nil { file, err := os.OpenFile(filePath, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0755) if err != nil { return fmt.Errorf("error opening temp file: %v", err) } if err := downloadScript(ctx, value, file); err != nil { file.Close() return err } if err := file.Close(); err != nil { return fmt.Errorf("error closing temp file: %v", err) } } else { // Trim leading spaces and newlines. value = strings.TrimLeft(value, " \n\v\f\t\r") if err := os.WriteFile(filePath, []byte(value), 0755); err != nil { return fmt.Errorf("error writing temp file: %v", err) } } return nil } func setupAndRunScript(ctx context.Context, metadataKey string, value string) error { // Make sure that the URL is valid for URL startup scripts var gcsScriptURL *url.URL if strings.HasSuffix(metadataKey, "-url") { var err error gcsScriptURL, err = url.Parse(strings.TrimSpace(value)) if err != nil { return err } } // Make temp directory. tmpDir, err := os.MkdirTemp(cfg.Get().MetadataScripts.RunDir, "metadata-scripts") if err != nil { return err } defer os.RemoveAll(tmpDir) tmpFile := filepath.Join(tmpDir, metadataKey) if runtime.GOOS == "windows" { tmpFile = normalizeFilePathForWindows(tmpFile, metadataKey, gcsScriptURL) } if err := writeScriptToFile(ctx, value, tmpFile, gcsScriptURL); err != nil { return fmt.Errorf("unable to write script to file: %v", err) } return runScript(tmpFile, metadataKey) } // Craft the command to run. func runScript(filePath string, metadataKey string) error { var cmd *exec.Cmd if strings.HasSuffix(filePath, ".ps1") { cmd = exec.Command("powershell.exe", append(powerShellArgs, filePath)...) } else { if runtime.GOOS == "windows" { cmd = exec.Command(filePath) } else { cmd = exec.Command(cfg.Get().MetadataScripts.DefaultShell, "-c", filePath) } } return runCmd(cmd, metadataKey) } func runCmd(c *exec.Cmd, name string) error { pr, pw, err := os.Pipe() if err != nil { return err } defer pr.Close() c.Stdout = pw c.Stderr = pw if err := c.Start(); err != nil { return err } pw.Close() in := bufio.NewScanner(pr) for { if !in.Scan() { if err := in.Err(); err != nil { logger.Errorf("error while communicating with %q script: %v", name, err) } break } logger.Log(logger.LogEntry{ Message: fmt.Sprintf("%s: %s", name, in.Text()), CallDepth: 3, Severity: logger.Info, }) } pr.Close() return c.Wait() } // getWantedKeys returns the list of keys to check for a given type of script and OS. func getWantedKeys(args []string, os string) ([]string, error) { if len(args) != 2 { return nil, errUsage } prefix := args[1] switch prefix { case "specialize": prefix = "sysprep-specialize" case "startup": if os == "windows" { prefix = "windows-" + prefix if !cfg.Get().MetadataScripts.StartupWindows { return nil, fmt.Errorf("windows startup scripts disabled in instance config") } } else { if !cfg.Get().MetadataScripts.Startup { return nil, fmt.Errorf("startup scripts disabled in instance config") } } case "shutdown": if os == "windows" { prefix = "windows-" + prefix if !cfg.Get().MetadataScripts.ShutdownWindows { return nil, fmt.Errorf("windows shutdown scripts disabled in instance config") } } else { if !cfg.Get().MetadataScripts.Shutdown { return nil, fmt.Errorf("shutdown scripts disabled in instance config") } } default: return nil, errUsage } var mdkeys []string var suffixes []string if os == "windows" { suffixes = []string{"ps1", "cmd", "bat", "url"} } else { suffixes = []string{"url"} // The 'bare' startup-script or shutdown-script key, not supported on Windows. mdkeys = append(mdkeys, fmt.Sprintf("%s-script", prefix)) } for _, suffix := range suffixes { mdkeys = append(mdkeys, fmt.Sprintf("%s-script-%s", prefix, suffix)) } return mdkeys, nil } func parseMetadata(md map[string]string, wanted []string) map[string]string { found := make(map[string]string) for _, key := range wanted { val, ok := md[key] if !ok || val == "" { continue } found[key] = val } return found } // getExistingKeys returns the wanted keys that are set in metadata. func getExistingKeys(ctx context.Context, wanted []string) (map[string]string, error) { for _, attrs := range []string{"/instance/attributes", "/project/attributes"} { md, err := getMetadataAttributes(ctx, attrs) if err != nil { return nil, err } if found := parseMetadata(md, wanted); len(found) != 0 { return found, nil } } return nil, nil } func logFormatWindows(e logger.LogEntry) string { now := time.Now().Format("2006/01/02 15:04:05") // 2006/01/02 15:04:05 GCEMetadataScripts This is a log message. return fmt.Sprintf("%s %s: %s", now, programName, e.Message) } func main() { ctx := context.Background() opts := logger.LogOpts{LoggerName: programName} if runtime.GOOS == "windows" { opts.Writers = []io.Writer{&utils.SerialPort{Port: "COM1"}, os.Stdout} opts.FormatFunction = logFormatWindows } else { opts.Writers = []io.Writer{os.Stdout} opts.FormatFunction = func(e logger.LogEntry) string { return e.Message } // Local logging is syslog; we will just use stdout in Linux. opts.DisableLocalLogging = true } var err error if err := cfg.Load(nil); err != nil { fmt.Fprintf(os.Stderr, "Failed to load instance configuration: %+v", err) os.Exit(1) } if !cfg.Get().Core.CloudLoggingEnabled { opts.DisableCloudLogging = true } // The keys to check vary based on the argument and the OS. Also functions to validate arguments. wantedKeys, err := getWantedKeys(os.Args, runtime.GOOS) if err != nil { fmt.Printf("%s\n", err.Error()) os.Exit(2) } projectID, err := getMetadataKey(ctx, "/project/project-id") if err == nil { opts.ProjectName = projectID } createdBy, err := getMetadataKey(ctx, "/instance/attributes/created-by") if err == nil { opts.MIG = createdBy } if err := logger.Init(ctx, opts); err != nil { fmt.Printf("Error initializing logger: %+v", err) os.Exit(1) } // Try flushing logs before exiting, if not flushed logs could go missing. defer logger.Close() logger.Infof("Starting %s scripts (version %s).", os.Args[1], version) scripts, err := getExistingKeys(ctx, wantedKeys) if err != nil { logger.Fatalf(err.Error()) } if len(scripts) == 0 { logger.Infof("No %s scripts to run.", os.Args[1]) return } for _, wantedKey := range wantedKeys { value, ok := scripts[wantedKey] if !ok { continue } logger.Infof("Found %s in metadata.", wantedKey) if err := setupAndRunScript(ctx, wantedKey, value); err != nil { logger.Warningf("Script %q failed with error: %v", wantedKey, err) continue } logger.Infof("%s exit status 0", wantedKey) } logger.Infof("Finished running %s scripts.", os.Args[1]) }