cmds/exec_agent/monitor.go (156 lines of code) (raw):

// Copyright (c) Facebook, Inc. and its affiliates. // // This source code is licensed under the MIT license found in the // LICENSE file in the root directory of this source tree. package main import ( "context" "errors" "fmt" "log" "net" "net/http" "net/rpc" "os" "sync" "syscall" ) type PollResponse struct { Stdout []byte Stderr []byte Alive bool } type monitor struct { proc *os.Process // TODO: these should really be readonly stdout *SafeBuffer stderr *SafeBuffer reaper *SafeSignal // there really shouldnt be multiple concurrent consumers // by design, but better be safe than sorry mu sync.Mutex } func newMonitor(proc *os.Process, stdout *SafeBuffer, stderr *SafeBuffer, reaper *SafeSignal) *monitor { return &monitor{proc, stdout, stderr, reaper, sync.Mutex{}} } func (m *monitor) Poll(_ int, reply *PollResponse) error { log.Printf("got a call for: poll") m.mu.Lock() defer m.mu.Unlock() reply.Stdout = make([]byte, m.stdout.Len()) if _, err := m.stdout.Read(reply.Stdout); err != nil { return fmt.Errorf("failed to read stdout: %w", err) } reply.Stderr = make([]byte, m.stderr.Len()) if _, err := m.stderr.Read(reply.Stderr); err != nil { return fmt.Errorf("failed to read stderr: %w", err) } // a signal value of 0 can be sent to the process to probe whether it's still alive // or not, it triggers no handling in the receiver process; apart from this, there // aren't many other ways of checking the process health (apart from poking the pid directly) reply.Alive = true if err := m.proc.Signal(syscall.Signal(0)); err != nil { if errors.Is(err, os.ErrProcessDone) { reply.Alive = false return nil } return fmt.Errorf("failed to send signal: %w", err) } return nil } func (m *monitor) Kill(_ int, _ *interface{}) error { log.Print("got a call for: kill") if err := m.proc.Signal(syscall.SIGKILL); err != nil { return fmt.Errorf("failed to send SIGKILL: %w", err) } return nil } func (m *monitor) Wait(_ int, _ *interface{}) error { log.Print("got a call for: wait") m.reaper.Signal() return nil } const sockFormat = "/tmp/exec_bin_sock_%d" type MonitorServer struct { addr string mon *monitor http *http.Server } func NewMonitorServer(proc *os.Process, stdout *SafeBuffer, stderr *SafeBuffer, reap *SafeSignal) *MonitorServer { addr := fmt.Sprintf(sockFormat, proc.Pid) mon := newMonitor(proc, stdout, stderr, reap) return &MonitorServer{addr, mon, nil} } func (m *MonitorServer) Serve() error { log.Printf("starting monitor...") if err := os.RemoveAll(m.addr); err != nil { return fmt.Errorf("failed to clear lingering socket %s: %w", m.addr, err) } listener, err := net.Listen("unix", m.addr) if err != nil { return fmt.Errorf("failed to listen on socket %s: %w", m.addr, err) } defer listener.Close() rpcServer := rpc.NewServer() if err := rpcServer.RegisterName("api", m.mon); err != nil { return fmt.Errorf("failed to register rpc api: %v", err) } log.Printf("starting RPC server at: %s", m.addr) m.http = &http.Server{ Addr: m.addr, Handler: rpcServer, } return m.http.Serve(listener) } func (m *MonitorServer) Shutdown() error { log.Printf("shutting down monitor...") if err := os.RemoveAll(m.addr); err != nil { return fmt.Errorf("failed to remove any socket %s: %w", m.addr, err) } if m.http != nil { // dont care about cancellation context return m.http.Shutdown(context.Background()) } return nil } type ErrCantConnect struct { w error } func (e ErrCantConnect) Error() string { return e.w.Error() } func (e ErrCantConnect) Unwrap() error { return e.w } type MonitorClient struct { addr string } func NewMonitorClient(pid int) *MonitorClient { addr := fmt.Sprintf(sockFormat, pid) return &MonitorClient{addr} } func (m *MonitorClient) Wait() error { client, err := rpc.DialHTTP("unix", m.addr) if err != nil { return &ErrCantConnect{fmt.Errorf("failed to connect to %s: %w", m.addr, err)} } defer client.Close() var reply interface{} if err := client.Call("api.Wait", 0, &reply); err != nil { return fmt.Errorf("failed to call rpc method: %w", err) } return nil } func (m *MonitorClient) Poll() (*PollResponse, error) { client, err := rpc.DialHTTP("unix", m.addr) if err != nil { return nil, &ErrCantConnect{fmt.Errorf("failed to connect to %s: %w", m.addr, err)} } defer client.Close() var reply PollResponse if err := client.Call("api.Poll", 0, &reply); err != nil { return nil, fmt.Errorf("failed to call rpc method: %w", err) } return &reply, nil } func (m *MonitorClient) Kill() error { client, err := rpc.DialHTTP("unix", m.addr) if err != nil { return &ErrCantConnect{fmt.Errorf("failed to connect to %s: %w", m.addr, err)} } defer client.Close() var reply interface{} if err := client.Call("api.Kill", 0, &reply); err != nil { return fmt.Errorf("failed to call rpc method: %w", err) } return nil }