cmd/gitaly-wrapper/main.go (145 lines of code) (raw):

package main import ( "errors" "fmt" "os" "os/exec" "os/signal" "path/filepath" "strconv" "syscall" "time" "gitlab.com/gitlab-org/gitaly/v16/internal/bootstrap" "gitlab.com/gitlab-org/gitaly/v16/internal/helper/env" "gitlab.com/gitlab-org/gitaly/v16/internal/log" "gitlab.com/gitlab-org/gitaly/v16/internal/ps" "golang.org/x/sys/unix" ) const ( envJSONLogging = "WRAPPER_JSON_LOGGING" ) func main() { logFormat := "text" if jsonLogging() { logFormat = "json" } logger, err := log.Configure(os.Stdout, logFormat, "info") if err != nil { fmt.Printf("configuring logger failed: %v", err) os.Exit(1) } logger = logger.WithField("wrapper", os.Getpid()) if len(os.Args) < 2 { logger.Error(fmt.Sprintf("usage: %s forking_binary [args]", os.Args[0])) os.Exit(1) } binary, arguments := os.Args[1], os.Args[2:] logger.Info("Wrapper started") pidFilePath := os.Getenv(bootstrap.EnvPidFile) if pidFilePath == "" { logger.Error(fmt.Sprintf("missing pid file ENV variable %q", bootstrap.EnvPidFile)) os.Exit(1) } logger.WithField("pid_file", pidFilePath).Info("finding process") process, err := findProcess(pidFilePath) if err != nil && !isRecoverable(err) { logger.WithError(err).Error("find process") os.Exit(1) } else if err != nil { logger.WithError(err).Error("find process") } if process != nil && isExpectedProcess(process, binary) { logger.Info("adopting a process") } else { logger.Info("spawning a process") proc, err := spawnProcess(logger, binary, arguments) if err != nil { logger.WithError(err).Error("spawn gitaly") os.Exit(1) } process = proc } logger = logger.WithField("process", process.Pid) logger.Info("monitoring process") forwardSignals(process, logger) // wait for isProcessAlive(process) { time.Sleep(1 * time.Second) } logger.Error("wrapper for process shutting down") } func isRecoverable(err error) bool { var numError *strconv.NumError return os.IsNotExist(err) || errors.As(err, &numError) } func findProcess(pidFilePath string) (*os.Process, error) { pid, err := readPIDFile(pidFilePath) if err != nil { return nil, err } // os.FindProcess on unix do not return an error if the process does not exist process, err := os.FindProcess(pid) if err != nil { return nil, err } if isProcessAlive(process) { return process, nil } return nil, nil } func spawnProcess(logger log.Logger, bin string, args []string) (*os.Process, error) { cmd := exec.Command(bin, args...) cmd.Env = append(os.Environ(), fmt.Sprintf("%s=true", bootstrap.EnvUpgradesEnabled)) cmd.Stdin = os.Stdin cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr if err := cmd.Start(); err != nil { return nil, err } // This cmd.Wait() is crucial. Without it we cannot detect if the command we just spawned has crashed. go func() { if err := cmd.Wait(); err != nil { logger.WithError(err).Error("waiting for supervised command") } }() return cmd.Process, nil } func isRuntimeSig(s os.Signal) bool { return s == unix.SIGURG } func forwardSignals(gitaly *os.Process, log log.Logger) { sigs := make(chan os.Signal, 1) go func() { for sig := range sigs { // In go1.14+, the go runtime issues SIGURG as an interrupt // to support pre-emptible system calls on Linux. We ignore // this signal since it's not relevant to the Gitaly process. if isRuntimeSig(sig) { continue } log.WithField("signal", sig).Warn("forwarding signal") if err := gitaly.Signal(sig); err != nil { log.WithField("signal", sig).WithError(err).Error("can't forward the signal") } } }() signal.Notify(sigs) } func readPIDFile(pidFilePath string) (int, error) { data, err := os.ReadFile(pidFilePath) if err != nil { return 0, err } return strconv.Atoi(string(data)) } func isProcessAlive(p *os.Process) bool { // After p exits, and after it gets reaped, this p.Signal will fail. It is crucial that p gets reaped. // If p was spawned by the current process, it will get reaped from a goroutine that does cmd.Wait(). // If p was spawned by someone else we rely on them to reap it, or on p to become an orphan. // In the orphan case p should get reaped by the OS (PID 1). return p.Signal(syscall.Signal(0)) == nil } func isExpectedProcess(p *os.Process, binary string) bool { command, err := ps.Comm(p.Pid) if err != nil { return false } if filepath.Base(command) == filepath.Base(binary) { return true } return false } func jsonLogging() bool { enabled, _ := env.GetBool(envJSONLogging, false) return enabled }