infra/module-swapper/cmd/swap.go (374 lines of code) (raw):

package cmd import ( "fmt" "log" "os" "path/filepath" "strings" "github.com/go-git/go-git/v5" "github.com/hashicorp/hcl/v2" "github.com/hashicorp/hcl/v2/gohcl" "github.com/hashicorp/hcl/v2/hclparse" "github.com/hashicorp/hcl/v2/hclwrite" "github.com/pmezard/go-difflib/difflib" "github.com/zclconf/go-cty/cty" giturl "github.com/chainguard-dev/git-urls" ) type LocalTerraformModule struct { Name string Dir string ModuleFQN string } const ( moduleBlockType = "module" sourceAttrib = "source" terraformExtension = "*.tf" restoreMarker = "[restore-marker]" linebreak = "\n" ) var ( localModules = []LocalTerraformModule{} ) // getRemoteURL gets the URL of a given remote from git repo at dir func getRemoteURL(dir, remoteName string) (string, error) { r, err := git.PlainOpen(dir) if err != nil { return "", err } rm, err := r.Remote(remoteName) if err != nil { return "", err } return rm.Config().URLs[0], nil } // trimAnySuffixes trims first matching suffix from slice of suffixes func trimAnySuffixes(s string, suffixes []string) string { for _, suffix := range suffixes { if strings.HasSuffix(s, suffix) { s = s[:len(s)-len(suffix)] return s } } return s } // getModuleNameRegistry returns module name and registry by parsing git remote func getModuleNameRegistry(dir string) (string, string, error) { remote, err := getRemoteURL(dir, "origin") if err != nil { return "", "", err } u, err := giturl.Parse(remote) if err != nil { return "", "", err } if u.Host != "github.com" { return "", "", fmt.Errorf("expected GitHub remote, got: %s", remote) } orgRepo := u.Path orgRepo = trimAnySuffixes(orgRepo, []string{"/", ".git"}) orgRepo = strings.TrimPrefix(orgRepo, "/") split := strings.Split(orgRepo, "/") if len(split) != 2 { return "", "", fmt.Errorf("expected GitHub remote of form https://github.com/ModuleRegistry/ModuleRepo, got: %s", remote) } org, repoName := split[0], split[1] // module repos are prefixed with terraform-google- if !strings.HasPrefix(repoName, "terraform-google-") { return "", "", fmt.Errorf("expected to find repo name prefixed with terraform-google-. Got: %s", repoName) } moduleName := strings.ReplaceAll(repoName, "terraform-google-", "") log.Printf("Module name set from remote to %s", moduleName) return moduleName, org, nil } // findSubModules generates slice of LocalTerraformModule for submodules func findSubModules(path, rootModuleFQN string) []LocalTerraformModule { var subModules = make([]LocalTerraformModule, 0) // if no modules dir, return empty slice if _, err := os.Stat(path); err != nil { log.Print("No submodules found") return subModules } files, err := os.ReadDir(path) if err != nil { log.Fatalf("Error finding submodules: %v", err) } absPath, err := filepath.Abs(path) if err != nil { log.Fatalf("Error finding submodule absolute path: %v", err) } for _, f := range files { if f.IsDir() { subModules = append(subModules, LocalTerraformModule{f.Name(), filepath.Join(absPath, f.Name()), fmt.Sprintf("%s//modules/%s", rootModuleFQN, f.Name())}) } } return subModules } // restoreModules restores old config as marked by restoreMarker func restoreModules(f []byte, p string) ([]byte, error) { if _, err := os.Stat(p); err != nil { return nil, err } strFile := string(f) if !strings.Contains(strFile, restoreMarker) { return f, nil } lines := strings.Split(strFile, linebreak) for i, line := range lines { if strings.Contains(line, restoreMarker) { lines[i] = strings.Split(line, restoreMarker)[1] } } return []byte(strings.Join(lines, linebreak)), nil } // matchedModule returns matching local TF module based on local path. func matchedModule(localPath string) *LocalTerraformModule { for _, l := range localModules { if localPath == l.Dir { return &l } } return nil } // localToRemote converts all local references in f to remote references. func localToRemote(f []byte, p string) ([]byte, error) { if _, err := os.Stat(p); err != nil { return nil, err } absPath, err := filepath.Abs(filepath.Dir(p)) if err != nil { return nil, fmt.Errorf("failed to get absolute path: %v", err) } f, err = restoreModules(f, p) if err != nil { return nil, err } currentReferences, err := moduleSourceRefs(f, p) if err != nil { return nil, fmt.Errorf("failed to write find module sources: %v", err) } newReferences := map[string]string{} for label, source := range currentReferences { localModule := matchedModule(filepath.Clean(filepath.Join(absPath, source))) if localModule == nil { log.Printf("no matches for %s", source) continue } newReferences[label] = localModule.ModuleFQN } if len(currentReferences) == 0 { return f, nil } updated, err := writeModuleRefs(f, p, newReferences) if err != nil { return nil, fmt.Errorf("failed to write updated module sources: %v", err) } // print diff info log.Printf("Modifications made to file %s", p) diff := difflib.UnifiedDiff{ A: difflib.SplitLines(string(f)), B: difflib.SplitLines(string(updated)), FromFile: "Original", ToFile: "Modified", Context: 3, } diffInfo, _ := difflib.GetUnifiedDiffString(diff) log.Println(diffInfo) return updated, nil } // remoteToLocal converts all remote references in f to local references. func remoteToLocal(f []byte, p string) ([]byte, error) { if _, err := os.Stat(p); err != nil { return nil, err } f = commentVersions(f) absPath, err := filepath.Abs(filepath.Dir(p)) if err != nil { return nil, fmt.Errorf("failed to get absolute path: %v", err) } fqnMap := make(map[string]LocalTerraformModule, len(localModules)) for _, l := range localModules { fqnMap[l.ModuleFQN] = l } currentReferences, err := moduleSourceRefs(f, p) if err != nil { return nil, fmt.Errorf("failed to write find module sources: %v", err) } newReferences := map[string]string{} for label, source := range currentReferences { localModule, exists := fqnMap[source] if !exists { continue } newModulePath, err := filepath.Rel(absPath, localModule.Dir) if err != nil { return nil, fmt.Errorf("failed to find relative path: %v", err) } newReferences[label] = newModulePath } if len(currentReferences) == 0 { return f, nil } updated, err := writeModuleRefs(f, p, newReferences) if err != nil { return nil, fmt.Errorf("failed to write updated module sources: %v", err) } // print diff info log.Printf("Modifications made to file %s", p) diff := difflib.UnifiedDiff{ A: difflib.SplitLines(string(f)), B: difflib.SplitLines(string(updated)), FromFile: "Original", ToFile: "Modified", Context: 3, } diffInfo, _ := difflib.GetUnifiedDiffString(diff) log.Println(diffInfo) return updated, nil } // commentVersions comments version attributes for local modules. func commentVersions(f []byte) []byte { strFile := string(f) lines := strings.Split(strFile, linebreak) for _, localModule := range localModules { // check if current file has module/submodules references that should be swapped if !strings.Contains(strFile, localModule.ModuleFQN) { continue } for i, line := range lines { if !strings.Contains(line, localModule.ModuleFQN) { continue } if i < len(lines)-1 && strings.Contains(lines[i+1], "version") && !strings.Contains(lines[i+1], restoreMarker) { leadingWhiteSpace := lines[i+1][:strings.Index(lines[i+1], "version")] lines[i+1] = fmt.Sprintf("%s# %s %s", leadingWhiteSpace, restoreMarker, lines[i+1]) } } } newExample := strings.Join(lines, linebreak) return []byte(newExample) } // getTFFiles returns a slice of valid TF file paths func getTFFiles(path string) []string { // validate path if _, err := os.Stat(path); err != nil { log.Fatal(fmt.Errorf("Unable to find %s : %v", path, err)) } var files = make([]string, 0) err := filepath.Walk(path, func(path string, info os.FileInfo, err error) error { if err != nil && info.IsDir() { return nil } isTFFile, _ := filepath.Match(terraformExtension, filepath.Base(path)) if isTFFile { files = append(files, path) } return nil }) if err != nil { log.Printf("Error walking files: %v", err) } return files } var ( // Partial schema of examples. exampleSchema = &hcl.BodySchema{ Blocks: []hcl.BlockHeaderSchema{ { Type: moduleBlockType, LabelNames: []string{"name"}, }, }, } // Partial schema of each module. moduleSchema = &hcl.BodySchema{ Attributes: []hcl.AttributeSchema{ { Name: sourceAttrib, }, }, } ) // moduleSourceRefs returns a map of module label to corresponding source references. func moduleSourceRefs(f []byte, TFFilePath string) (map[string]string, error) { refs := map[string]string{} p, err := hclparse.NewParser().ParseHCL(f, TFFilePath) if err != nil { return nil, fmt.Errorf("failed to parse hcl: %v", err) } c, _, diags := p.Body.PartialContent(exampleSchema) if diags.HasErrors() { return nil, fmt.Errorf("failed to parse example content: %v", diags.Error()) } for _, b := range c.Blocks { if b.Type != moduleBlockType { continue } if len(b.Labels) != 1 { log.Printf("got multiple labels %v, module should only have one", b.Labels) continue } content, _, diags := b.Body.PartialContent(moduleSchema) if diags.HasErrors() { log.Printf("skipping %s module, failed to parse module content: %v", b.Labels[0], diags.Error()) continue } sourcrAttr, exists := content.Attributes[sourceAttrib] if !exists { log.Printf("skipping %s module, no source attribute", b.Labels[0]) continue } var sourceName string diags = gohcl.DecodeExpression(sourcrAttr.Expr, nil, &sourceName) if diags.HasErrors() { log.Printf("skipping %s module, failed to decode source value: %v", b.Labels[0], diags.Error()) continue } refs[b.Labels[0]] = sourceName } return refs, nil } // writeModuleRefs appends or overwrites provided moduleRefs to file f. func writeModuleRefs(f []byte, p string, moduleRefs map[string]string) ([]byte, error) { wf, diags := hclwrite.ParseConfig(f, p, hcl.Pos{}) if diags.HasErrors() { return nil, fmt.Errorf("failed to parse hcl: %v", diags.Error()) } for _, b := range wf.Body().Blocks() { if b.Type() != moduleBlockType { continue } if len(b.Labels()) != 1 { log.Printf("got multiple labels %v, module should only have one", b.Labels()) continue } newSource, exists := moduleRefs[b.Labels()[0]] if !exists { continue } b.Body().SetAttributeValue(sourceAttrib, cty.StringVal(newSource)) } var testS strings.Builder _, err := wf.WriteTo(&testS) if err != nil { return nil, fmt.Errorf("failed to write hcl: %v", diags.Error()) } return []byte(testS.String()), nil } func SwapModules(rootPath, moduleRegistrySuffix, moduleRegistryPrefix, subModulesDir, examplesDir string, restore bool) { rootPath = filepath.Clean(rootPath) moduleName, foundRegistryPrefix, err := getModuleNameRegistry(rootPath) if err != nil && moduleRegistryPrefix == "" { log.Printf("failed to get module name and registry: %v", err) return } if moduleRegistryPrefix != "" { foundRegistryPrefix = moduleRegistryPrefix } // add root module to slice of localModules localModules = append(localModules, LocalTerraformModule{moduleName, rootPath, fmt.Sprintf("%s/%s/%s", foundRegistryPrefix, moduleName, moduleRegistrySuffix)}) examplesPath := fmt.Sprintf("%s/%s", rootPath, examplesDir) subModulesPath := fmt.Sprintf("%s/%s", rootPath, subModulesDir) // add submodules, if any to localModules submods := findSubModules(subModulesPath, localModules[0].ModuleFQN) localModules = append(localModules, submods...) // find all TF files in examples dir to process exampleTFFiles := getTFFiles(examplesPath) for _, TFFilePath := range exampleTFFiles { file, err := os.ReadFile(TFFilePath) if err != nil { log.Printf("Error reading file: %v", err) } var newFile []byte if restore { newFile, err = localToRemote(file, TFFilePath) } else { newFile, err = remoteToLocal(file, TFFilePath) } if err != nil { log.Printf("Error processing file: %v", err) } if newFile != nil { err = os.WriteFile(TFFilePath, newFile, 0644) if err != nil { log.Printf("Error writing file: %v", err) } } } }