agent/taskengine/session.go (179 lines of code) (raw):

package taskengine import ( "errors" "fmt" "strconv" "time" "github.com/aliyun/aliyun_assist_client/agent/log" "github.com/aliyun/aliyun_assist_client/agent/metrics" "github.com/aliyun/aliyun_assist_client/agent/session/channel" "github.com/aliyun/aliyun_assist_client/agent/session/port" "github.com/aliyun/aliyun_assist_client/agent/session/shell" "github.com/aliyun/aliyun_assist_client/agent/taskengine/models" "github.com/aliyun/aliyun_assist_client/agent/util" ) type SessionTask struct { taskId string sessionId string websocketUrl string cmdContent string username string passwordName string targetHost string portNumber string flowLimit int sessionChannel *channel.SessionChannel shellPlugin *shell.ShellPlugin portPlugin *port.PortPlugin cancelFlag util.CancelFlag } func NewSessionTask(sessionId string, websocketUrl string, taskId string, cmdContent string, username string, passwordName string, targetHost string, portNumber string, flowLimit int) *SessionTask { task := &SessionTask{ sessionId: sessionId, taskId: taskId, websocketUrl: websocketUrl, cmdContent: cmdContent, passwordName: passwordName, username: username, targetHost: targetHost, portNumber: portNumber, flowLimit: flowLimit, cancelFlag: util.NewChanneledCancelFlag(), } return task } func ReportSessionResult(taskID string, status string) { url := util.GetSessionStatusService() reportStatus := "Failed" if status == shell.Ok || status == shell.Notified || status == shell.Timeout { reportStatus = "Success" } param := fmt.Sprintf("?channelId=%s&status=%s&errorcode=%s", taskID, reportStatus, status) url += param log.GetLogger().Printf("post = %s", url) _, err := util.HttpPost(url, "", "text") if err != nil { metrics.GetTaskFailedEvent( "errormsg", fmt.Sprintf("report session result err: %s", err.Error()), "url", url, "taskid", taskID, ).ReportEvent() log.GetLogger().Printf("HttpPost url %s error:%s ", url, err.Error()) } } func (sessionTask *SessionTask) isPortForwardTask() bool { return sessionTask.portNumber != "" } func (sessionTask *SessionTask) runTask() (string, error) { ret := GetSessionFactory().ContainsTask(sessionTask.sessionId) if ret == true { log.GetLogger().Errorln("NewSessionChannel failed") return shell.Session_id_duplicate, errors.New("NewSessionChannel failed") } if sessionTask.isPortForwardTask() { port_num, _ := strconv.Atoi(sessionTask.portNumber) sessionTask.portPlugin = port.NewPortPlugin(sessionTask.sessionId, sessionTask.targetHost, port_num, sessionTask.flowLimit) } else { sessionTask.shellPlugin = shell.NewShellPlugin(sessionTask.sessionId, sessionTask.cmdContent, sessionTask.username, sessionTask.passwordName, sessionTask.flowLimit) } GetSessionFactory().AddSessionTask(sessionTask) host := util.GetServerHost() if host == "" { return shell.Init_channel_failed, errors.New("No available host") } websocketUrl := "wss://" + host + "/luban/session/backend?channelId=" + sessionTask.sessionId log.GetLogger().Infoln("url: ", websocketUrl) var err error var session_channel *channel.SessionChannel if sessionTask.isPortForwardTask() { session_channel, err = channel.NewSessionChannel(websocketUrl, sessionTask.sessionId, sessionTask.portPlugin.InputStreamMessageHandler, sessionTask.cancelFlag) } else { session_channel, err = channel.NewSessionChannel(websocketUrl, sessionTask.sessionId, sessionTask.shellPlugin.InputStreamMessageHandler, sessionTask.cancelFlag) sessionTask.sessionChannel = session_channel } if err != nil { log.GetLogger().Errorln("NewSessionChannel failed", err) return shell.Init_channel_failed, fmt.Errorf("NewSessionChannel failed: %v", err) } sessionTask.sessionChannel = session_channel err = session_channel.Open() if err != nil { log.GetLogger().Errorln("NewSessionChannel failed", err) return shell.Open_channel_failed, fmt.Errorf("NewSessionChannel failed: %v", err) } done := make(chan int, 1) error_code := shell.Ok go func() { time.Sleep(1 * time.Second) if sessionTask.isPortForwardTask() { log.GetLogger().Infoln("run portPlugin") error_code, err = sessionTask.portPlugin.Execute(session_channel, sessionTask.cancelFlag) } else { log.GetLogger().Infoln("run shellPlugin") error_code, err = sessionTask.shellPlugin.Execute(session_channel, sessionTask.cancelFlag) } done <- 1 }() select { case <-done: log.GetLogger().Println("shell end", sessionTask.sessionId) case <-time.After(time.Duration(3600*3) * time.Second): log.GetLogger().Println("shell timeout", sessionTask.sessionId) error_code = shell.Timeout } return error_code, err } func DoSessionTask(tasks []models.SessionTaskInfo) { go func() { for _, s := range tasks { session := NewSessionTask(s.SessionId, s.WebsocketUrl, s.SessionId, s.CmdContent, s.Username, s.Password, s.TargetHost, s.PortNumber, s.FlowLimit) session.RunTask(s.SessionId) } }() } func (sessionTask *SessionTask) RunTask(taskid string) error { log.GetLogger().Infoln("run task", taskid, sessionTask.sessionId) code, err := sessionTask.runTask() ReportSessionResult(taskid, code) if sessionTask.sessionChannel != nil { sessionTask.sessionChannel.Close() } GetSessionFactory().RemoveTask(sessionTask.sessionId) if err != nil { metrics.GetTaskFailedEvent( "errormsg", err.Error(), "taskid", sessionTask.taskId, "sessionid", sessionTask.sessionId, "code", code, "wsURL", sessionTask.websocketUrl, ).ReportEvent() metrics.GetSessionFailedEvent( "sessionId", sessionTask.sessionId, "errormsg", err.Error(), "taskid", sessionTask.taskId, "code", code, "wsURL", sessionTask.websocketUrl, ).ReportEvent() } return err } func (sessionTask *SessionTask) StopTask() error { log.GetLogger().Infoln("stop task", sessionTask.taskId) if sessionTask.shellPlugin != nil || sessionTask.portPlugin != nil { sessionTask.cancelFlag.Set(util.Completed) } else { log.GetLogger().Errorln("sesison plugin is invalid") } return nil }