agent/util/process/process.go (270 lines of code) (raw):

package process import ( "errors" "fmt" "io" "os" "os/exec" "strconv" "strings" "sync" "time" "github.com/aliyun/aliyun_assist_client/agent/log" "github.com/aliyun/aliyun_assist_client/agent/util/process/processcollection" "github.com/aliyun/aliyun_assist_client/common/executil" "github.com/aliyun/aliyun_assist_client/thirdparty/sirupsen/logrus" ) const ( Success int = iota Fail Timeout groupsIdentifier = "groups=" ) type WaitProcessResult struct { processState *os.ProcessState err error } type CmdOption func(*exec.Cmd) error type ProcessCmd struct { canceledChan chan bool command *exec.Cmd user_name string password string homeDir string env []string commandOptions []CmdOption collection processcollection.ProcessCollection } func NewProcessCmd(options ...CmdOption) *ProcessCmd { p := &ProcessCmd{ commandOptions: options, } return p } func NewProcessCmdWithProcessTree(groupName string, options ...CmdOption) (*ProcessCmd, error) { p := &ProcessCmd{ commandOptions: options, } p.collection, _ = processcollection.CreateProcessTree(groupName) return p, nil } type cancellableAndSafeWriter struct { baseWriter io.Writer cancelChannel chan bool cancelled bool lock sync.Mutex } type ReadCallbackFunc func(stdoutWriter io.Reader, stderrWriter io.Reader) func (p *ProcessCmd) Cancel() error { if p.collection != nil { return p.collection.KillAll() } else if p.command != nil { if err := p.command.Process.Kill(); err != nil { return fmt.Errorf("Failed to kill process %d: %s", p.command.Process.Pid, err) } } return nil } func (p *ProcessCmd) SetUserInfo(name string) { p.user_name = name } func (p *ProcessCmd) SetPasswordInfo(password string) { p.password = password } func (p *ProcessCmd) SetHomeDir(homeDir string) { p.homeDir = homeDir } func (p *ProcessCmd) SetEnv(env []string) { p.env = env } func (p *ProcessCmd) SyncRunSimple(commandName string, commandArguments []string, timeOut int) error { p.command = executil.Command(commandName, commandArguments...) logger := log.GetLogger().WithFields(logrus.Fields{ "command": p.command.Args, "timeout": timeOut, }) if err := p.prepareProcess(); err != nil { return err } if err := p.command.Start(); err != nil { logger.WithError(err).Errorln("error occurred starting the command") return errors.New("error occurred starting the command") } if p.collection != nil { if err := p.collection.AddProcess(p.command.Process); err != nil { logger.WithError(err).Error("add process into collection failed") p.command.Process.Kill() return fmt.Errorf("add process into collection failed: %v", err) } logger.Infof("add process %d into collection", p.command.Process.Pid) } finished := make(chan error, 1) go func() { finished <- p.command.Wait() }() var err error select { case err = <-finished: logger.Infoln("Process completed.") if err != nil { logger.WithError(err).Infoln("error in run command") } case <-time.After(time.Duration(timeOut) * time.Second): logger.Errorln("Timeout in run command.") err = errors.New("cmd run timeout") if p.collection != nil { if err := p.collection.KillAll(); err != nil { logger.WithError(err).Error("kill all process in collection failed") } } else { p.command.Process.Kill() } } if p.collection != nil { p.collection.Dispose() } return err } func (p *ProcessCmd) SyncRun( workingDir string, commandName string, commandArguments []string, stdoutWriter io.Writer, stderrWriter io.Writer, stdinReader io.Reader, callbackFunc ReadCallbackFunc, timeOut int) (exitCode int, status int, err error) { status = Success exitCode = 0 p.command = executil.Command(commandName, commandArguments...) p.command.Stdout = stdoutWriter p.command.Stderr = stderrWriter p.command.Stdin = stdinReader p.command.Dir = workingDir p.command.Env = p.env if err := p.prepareProcess(); err != nil { return 0, Fail, err } if p.user_name != "" { if err := p.addCredential(); err != nil { return 0, Fail, err } } if err = p.command.Start(); err != nil { log.GetLogger().Errorln("error occurred starting the command", err) exitCode = 1 return exitCode, Fail, err } if p.collection != nil { if err := p.collection.AddProcess(p.command.Process); err != nil { log.GetLogger().WithError(err).Error("add process into collection failed") p.command.Process.Kill() return 1, Fail, fmt.Errorf("add process into collection failed: %v", err) } else { log.GetLogger().Infof("add process %d into collection %s success", p.command.Process.Pid, p.collection.Name()) } } finished := make(chan WaitProcessResult, 1) go func() { processState, err := p.command.Process.Wait() finished <- WaitProcessResult{ processState: processState, err: err, } }() var timeoutChannel <-chan time.Time = nil if timeOut > 0 { timer := time.NewTimer(time.Duration(timeOut) * time.Second) defer timer.Stop() timeoutChannel = timer.C } select { case waitProcessResult := <-finished: log.GetLogger().Println("Command completed.", commandName) if waitProcessResult.processState != nil { if waitProcessResult.err != nil { log.GetLogger().WithFields(logrus.Fields{ "processState": waitProcessResult.processState, }).WithError(waitProcessResult.err).Error("os.Process.Wait() returns error with valid process state") } exitCode = waitProcessResult.processState.ExitCode() // Sleep 200ms to allow remaining data to be copied back time.Sleep(time.Duration(200) * time.Millisecond) // Explicitly break select statement in case timer also times out break } else { exitCode = 1 return exitCode, Fail, waitProcessResult.err } case <-timeoutChannel: log.GetLogger().Errorln("Timeout in run command.", commandName) exitCode = 1 status = Timeout err = errors.New("timeout") if p.collection != nil { if err := p.collection.KillAll(); err != nil { log.GetLogger().WithError(err).Error("kill all process in collection failed") } } else { p.command.Process.Kill() } } if p.user_name != "" { p.removeCredential() } if p.collection != nil { p.collection.Dispose() } return exitCode, status, err } func (p *ProcessCmd) Pid() int { if p.command == nil || p.command.Process == nil { return -1 } return p.command.Process.Pid } func GetUserCredentials(sessionUser string) (uint32, uint32, []uint32, error) { uidCmdArgs := append([]string{"-c"}, fmt.Sprintf("id -u %s", sessionUser)) cmd := executil.Command("sh", uidCmdArgs...) out, err := cmd.Output() if err != nil { log.GetLogger().Errorf("Failed to retrieve uid for %s: %v", sessionUser, err) return 0, 0, nil, err } uid, err := strconv.Atoi(strings.TrimSpace(string(out))) if err != nil { log.GetLogger().Errorf("%s not found: %v", sessionUser, err) return 0, 0, nil, err } gidCmdArgs := append([]string{"-c"}, fmt.Sprintf("id -g %s", sessionUser)) cmd = executil.Command("sh", gidCmdArgs...) out, err = cmd.Output() if err != nil { log.GetLogger().Errorf("Failed to retrieve gid for %s: %v", sessionUser, err) return 0, 0, nil, err } gid, err := strconv.Atoi(strings.TrimSpace(string(out))) if err != nil { log.GetLogger().Errorf("%s not found: %v", sessionUser, err) return 0, 0, nil, err } // Get the list of associated groups groupNamesCmdArgs := append([]string{"-c"}, fmt.Sprintf("id %s", sessionUser)) cmd = executil.Command("sh", groupNamesCmdArgs...) out, err = cmd.Output() if err != nil { log.GetLogger().Errorf("Failed to retrieve groups for %s: %v", sessionUser, err) return 0, 0, nil, err } // Example format of output: uid=1873601143(ssm-user) gid=1873600513(domain users) groups=1873600513(domain users),1873601620(joiners),1873601125(aws delegated add workstations to domain users) // Extract groups from the output groupsIndex := strings.Index(string(out), groupsIdentifier) var groupIds []uint32 if groupsIndex > 0 { // Extract groups names and ids from the output groupNamesAndIds := strings.Split(string(out)[groupsIndex+len(groupsIdentifier):], ",") // Extract group ids from the output for _, value := range groupNamesAndIds { groupId, err := strconv.Atoi(strings.TrimSpace(value[:strings.Index(value, "(")])) if err != nil { log.GetLogger().Errorf("Failed to retrieve group id from %s: %v", value, err) return 0, 0, nil, err } groupIds = append(groupIds, uint32(groupId)) } } return uint32(uid), uint32(gid), groupIds, nil }