build/bazel/fix_grpc_gateway.go (59 lines of code) (raw):

// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// fix_grpc_gateway applies patches to the file generated by gen-grpc-gateway.
//
// The go file generated by gen-grpc-gateway is expected to be placed in the
// same package as the go library generated from the proto. For googleapis, we
// do not control the BUILD rule for how that library is generated (or used).
//
// Instead, we modify the gateway file after it is generated so that we can
// use it as a standalone package.
package main

import (
	"errors"
	"flag"
	"fmt"
	"io/ioutil"
	"os"
	"regexp"
	"strings"
)

var (
	inputGatewayFile  = flag.String("input_gateway_file", "", "Original gateway file .")
	outputGatewayFile = flag.String("output_gateway_file", "", "Modified gateway file.")
	protoImports      = flag.String("proto_imports", "", "Extra imports required by the gateway.")
)

// packageClauseRE matches the first line that starts with a package clause,
// including its trailing newline. It is anchored to the start of a line so
// that the word "package" appearing mid-line (e.g. inside a doc comment) is
// not mistaken for the package clause. A comment line that itself begins with
// "package " could still match — NOTE(review): acceptable for generated input.
var packageClauseRE = regexp.MustCompile(`(?m)^package [^\n]*\n`)

// validateFlags checks that every required flag was given a usable value and
// returns a descriptive error for the first one that was not.
func validateFlags() error {
	if *inputGatewayFile == "" {
		return errors.New("--input_gateway_file must be specified")
	}
	if *outputGatewayFile == "" {
		return errors.New("--output_gateway_file must be specified")
	}
	// A value made only of separators/whitespace (e.g. ",") names no import
	// and is treated the same as an unset flag.
	if len(parseProtoImports(*protoImports)) == 0 {
		return errors.New("--proto_imports must be specified")
	}
	return nil
}

// parseProtoImports splits the comma-separated --proto_imports value into
// import paths. Surrounding whitespace is trimmed and empty entries (such as
// one left by a trailing comma) are dropped, since they would otherwise be
// emitted as an invalid `. ""` import.
func parseProtoImports(value string) []string {
	var paths []string
	for _, p := range strings.Split(value, ",") {
		if p = strings.TrimSpace(p); p != "" {
			paths = append(paths, p)
		}
	}
	return paths
}

// addDotImports returns contents with a new import block, dot-importing each
// entry of importPaths, inserted immediately after the package clause. It
// returns an error instead of panicking when no package clause is found.
func addDotImports(contents string, importPaths []string) (string, error) {
	loc := packageClauseRE.FindStringIndex(contents)
	if loc == nil {
		return "", errors.New("could not find package clause")
	}

	var b strings.Builder
	b.WriteString(contents[:loc[1]])
	b.WriteString("\nimport (\n")
	for _, p := range importPaths {
		// %q produces a correctly escaped Go string literal for the path.
		fmt.Fprintf(&b, "\t. %q\n", p)
	}
	b.WriteString(")\n\n")
	b.WriteString(contents[loc[1]:])
	return b.String(), nil
}

// run reads the file named by --input_gateway_file, inserts a block that
// dot-imports every package listed in --proto_imports right after its package
// clause, and writes the result to --output_gateway_file.
func run() error {
	// Read-in the contents of the gateway file generated by gen-grpc-gateway.
	data, err := ioutil.ReadFile(*inputGatewayFile)
	if err != nil {
		return fmt.Errorf("Could not read input file %s: %w", *inputGatewayFile, err)
	}

	// Dot-import the proto packages from which this gateway was generated, so
	// that the gateway's unqualified references to them still resolve once the
	// file lives in its own package.
	contents, err := addDotImports(string(data), parseProtoImports(*protoImports))
	if err != nil {
		return fmt.Errorf("could not patch %s: %w", *inputGatewayFile, err)
	}

	// Output the modified gateway file.
	if err := ioutil.WriteFile(*outputGatewayFile, []byte(contents), 0644); err != nil {
		return fmt.Errorf("could not write output file %s: %w", *outputGatewayFile, err)
	}
	return nil
}

func main() {
	flag.Parse()
	if err := validateFlags(); err != nil {
		fmt.Fprintf(os.Stderr, "Invalid flags: %s\n", err)
		flag.Usage()
		os.Exit(1)
	}
	if err := run(); err != nil {
		fmt.Fprintf(os.Stderr, "%s\n", err)
		os.Exit(1)
	}
}