build/bazel/fix_grpc_gateway.go (59 lines of code) (raw):
//
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// fix_grpc_gateway applies patches to the file generated by gen-grpc-gateway.
//
// The Go file generated by gen-grpc-gateway is expected to be placed in the
// same package as the Go library generated from the proto. For googleapis, we
// do not control the BUILD rule for how that library is generated (or used).
//
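// Instead, we modify the gateway file after it is generated so that we can
// use it as a standalone package: the proto packages it depends on are
// dot-imported in a new import block placed right after its package clause.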
package main
import (
"errors"
"flag"
"fmt"
"io/ioutil"
"os"
"regexp"
"strings"
)
// Command-line flags. Each one must be given a non-empty value; this is
// enforced by validateFlags before any work is done.
var (
	// inputGatewayFile is the path of the file produced by gen-grpc-gateway.
	inputGatewayFile = flag.String("input_gateway_file", "", "Original gateway file.")
	// outputGatewayFile is the path the patched gateway file is written to.
	outputGatewayFile = flag.String("output_gateway_file", "", "Modified gateway file.")
	// protoImports is a comma-separated list of import paths; each entry is
	// dot-imported into the patched gateway file.
	protoImports = flag.String("proto_imports", "", "Comma-separated import paths of the proto packages required by the gateway; each is dot-imported.")
)
// validateFlags reports an error naming the first required flag that was left
// empty. Flags are checked in the order --input_gateway_file,
// --output_gateway_file, --proto_imports; a nil result means all are set.
func validateFlags() error {
	required := []struct {
		value   *string
		message string
	}{
		{inputGatewayFile, "--input_gateway_file must be specified"},
		{outputGatewayFile, "--output_gateway_file must be specified"},
		{protoImports, "--proto_imports must be specified"},
	}
	for _, r := range required {
		if *r.value == "" {
			return errors.New(r.message)
		}
	}
	return nil
}
// packageClauseRE matches the package clause of a Go source file: a line that
// starts with the "package" keyword, including any trailing comment on that
// line and its terminating newline (when there is one). Anchoring to the start
// of a line keeps it from matching the word "package" inside a header comment.
var packageClauseRE = regexp.MustCompile(`(?m)^package[ \t]+\w+[^\n]*\n?`)

// parseImportPaths splits a comma-separated list of import paths, trimming
// surrounding whitespace and dropping empty entries (e.g. from a trailing or
// doubled comma), which would otherwise yield invalid imports like `. ""`.
func parseImportPaths(list string) []string {
	var paths []string
	for _, p := range strings.Split(list, ",") {
		if p = strings.TrimSpace(p); p != "" {
			paths = append(paths, p)
		}
	}
	return paths
}

// insertDotImports returns src with a new import block, dot-importing each of
// paths, placed immediately after the package clause. It returns an error when
// src contains no package clause.
func insertDotImports(src string, paths []string) (string, error) {
	loc := packageClauseRE.FindStringIndex(src)
	if loc == nil {
		return "", errors.New("no package clause found")
	}
	var b strings.Builder
	b.WriteString(src[:loc[1]])
	b.WriteString("\nimport (\n")
	for _, p := range paths {
		// %q produces a correctly escaped Go string literal for the path.
		fmt.Fprintf(&b, "\t. %q\n", p)
	}
	b.WriteString(")\n\n")
	b.WriteString(src[loc[1]:])
	return b.String(), nil
}

// run reads --input_gateway_file, adds an import block that dot-imports every
// package listed in --proto_imports right after the package clause, and
// writes the result to --output_gateway_file. Nothing is written if reading
// or patching fails.
func run() error {
	src, err := ioutil.ReadFile(*inputGatewayFile)
	if err != nil {
		return fmt.Errorf("Could not read input file %s: %w", *inputGatewayFile, err)
	}

	paths := parseImportPaths(*protoImports)
	if len(paths) == 0 {
		return fmt.Errorf("--proto_imports %q contains no import paths", *protoImports)
	}

	patched, err := insertDotImports(string(src), paths)
	if err != nil {
		return fmt.Errorf("Could not patch input file %s: %w", *inputGatewayFile, err)
	}

	if err = ioutil.WriteFile(*outputGatewayFile, []byte(patched), 0644); err != nil {
		return fmt.Errorf("Could not write output file %s: %w", *outputGatewayFile, err)
	}
	return nil
}
// main is the command entry point. It parses and validates the flags, then
// patches the gateway file. Any failure is printed to stderr and ends the
// process with exit status 1; a flag error additionally prints the usage text.
func main() {
	flag.Parse()

	var err error
	if err = validateFlags(); err != nil {
		fmt.Fprintf(os.Stderr, "Invalid flags: %s\n", err)
		flag.Usage()
		os.Exit(1)
	}
	if err = run(); err != nil {
		fmt.Fprintf(os.Stderr, "%s\n", err)
		os.Exit(1)
	}
}