agent/session/shell/shell.go (229 lines of code) (raw):

package shell import ( "bufio" "bytes" "context" "encoding/json" "errors" "fmt" "time" "unicode/utf8" "github.com/aliyun/aliyun_assist_client/agent/log" "github.com/aliyun/aliyun_assist_client/agent/session/channel" "github.com/aliyun/aliyun_assist_client/agent/session/message" "github.com/aliyun/aliyun_assist_client/agent/util" "github.com/aliyun/aliyun_assist_client/agent/util/osutil" "github.com/aliyun/aliyun_assist_client/thirdparty/sirupsen/logrus" ) const ( Ok = "Ok" Init_channel_failed = "Init_channel_failed" Open_channel_failed = "Open_channel_failed" Session_id_duplicate = "Session_id_duplicate" Process_data_error = "Process_data_error" Open_pty_failed = "Open_pty_failed" Timeout = "Timeout" Notified = "Notified" Unknown_error = "Unknown_error" ) type SizeData struct { Cols uint32 `json:"cols"` Rows uint32 `json:"rows"` } const ( sendPackageSize = 1024 // 发送的payload大小上限,单位 B defaultSendSpeed = 200 // 默认的最大数据发送速率,单位 kbps defaultSendInterval = 1000 / (defaultSendSpeed * 1024 / 8 / sendPackageSize) // writeloop的循环间隔时间 单位ms ) func NewShellPlugin(id string, cmdContent string, username string, passwordName string, flowLimit int) *ShellPlugin { plugin := &ShellPlugin{ id: id, cmdContent: cmdContent, username: username, passwordName: passwordName, sendInterval: defaultSendInterval, logger: log.GetLogger().WithFields(logrus.Fields{ "sessionType": "shell", "sessionId": id, }), } if flowLimit > 0 { plugin.sendInterval = 1000 / (flowLimit / 8 / sendPackageSize) } else { flowLimit = defaultSendSpeed * 1024 } plugin.logger.Infof("Init send speed, speed[%d]bps sendInterval[%d]ms\n", flowLimit, plugin.sendInterval) return plugin } func (p *ShellPlugin) Execute(dataChannel channel.ISessionChannel, cancelFlag util.CancelFlag) (errorCode string, pluginErr error) { p.dataChannel = dataChannel // p.exitCtx and p.exitFunc are initialized in Execute() instead of NewPortPlugin() to ensure // that p.exitCtx can be released eventually. // Although InputStreamMessageHandler() may be called before Execute(), it will wait until // p.conn is successfully created before starting to process data (possibly calling p.exitFunc), // so p.exitFunc will not be called before it be initialized. p.exitCtx, p.exitFunc = context.WithCancelCause(context.Background()) errorCode = Ok defer func() { p.logger.Infoln("stop in run ShellPlugin") if err := p.stop(); err != nil { p.logger.Errorf("Error occurred while closing pty: %v", err) } if err := recover(); err != nil { p.logger.Errorf("Error occurred while executing plugin %s: \n%v", p.id, err) errorCode = Unknown_error if v, ok := err.(error); ok { pluginErr = v } else { pluginErr = fmt.Errorf(fmt.Sprint(err)) } } }() p.logger.Infoln("start pty") pluginErr = StartPty(p) if pluginErr != nil { errorString := fmt.Sprintf("Unable to start shell: %s", pluginErr) p.logger.Errorln(errorString) errorCode = Open_pty_failed return } p.logger.Infoln("start pty success") go func() { select { case <-cancelFlag.C(): cancelState := cancelFlag.State() if cancelState == util.Canceled { p.exitFunc(errors.New(Timeout)) } else { p.exitFunc(errors.New(Notified)) } p.logger.Debugf("Cancel flag set to %v in session", cancelState) p.logger.Info("The session was cancelled") case <-p.exitCtx.Done(): cancelFlag.Set(util.ShutDown) } }() go func() { p.writePump() }() p.logger.Infof("Plugin %s started", p.id) <-p.exitCtx.Done() errorCode = context.Cause(p.exitCtx).Error() p.logger.Infoln("Plugin done", p.id, errorCode) if osutil.GetOsType() == osutil.OSLinux || osutil.GetOsType() == osutil.OSFreebsd { p.waitPid() } return } func (p *ShellPlugin) writePump() { defer func() { if err := recover(); err != nil { p.logger.Println("WritePump thread crashed with message: \n", err) } }() stdoutBytes := make([]byte, sendPackageSize) reader := bufio.NewReader(p.stdout) // Wait for all input commands to run. time.Sleep(time.Second) var unprocessedBuf bytes.Buffer for { select { case <-p.exitCtx.Done(): return default: stdoutBytesLen, err := reader.Read(stdoutBytes) if err != nil { p.logger.Debugf("Failed to read from pty master: %s", err) p.exitFunc(errors.New(Ok)) return } // unprocessedBuf contains incomplete utf8 encoded unicode bytes returned after processing of stdoutBytes if unprocessedBuf, err = p.processStdoutData(stdoutBytes, stdoutBytesLen, unprocessedBuf); err != nil { p.logger.Errorf("Error processing stdout data, %v", err) p.exitFunc(errors.New(Process_data_error)) return } // Wait for stdout to process more data time.Sleep(time.Duration(p.sendInterval) * time.Millisecond) } } } // processStdoutData reads utf8 encoded unicode characters from stdoutBytes and sends it over websocket channel. func (p *ShellPlugin) processStdoutData( stdoutBytes []byte, stdoutBytesLen int, unprocessedBuf bytes.Buffer) (bytes.Buffer, error) { // append stdoutBytes to unprocessedBytes and then read rune from appended bytes to send it over websocket channel unprocessedBytes := unprocessedBuf.Bytes() unprocessedBytes = append(unprocessedBytes[:], stdoutBytes[:stdoutBytesLen]...) runeReader := bufio.NewReader(bytes.NewReader(unprocessedBytes)) var processedBuf bytes.Buffer unprocessedBytesLen := len(unprocessedBytes) i := 0 for i < unprocessedBytesLen { // read stdout bytes as utf8 encoded unicode character stdoutRune, stdoutRuneLen, err := runeReader.ReadRune() if err != nil { return processedBuf, fmt.Errorf("failed to read rune from reader: %s", err) } // Invalid utf8 encoded character results into RuneError. if stdoutRune == utf8.RuneError { // If invalid character is encountered within last 3 bytes of buffer (utf8 takes 1-4 bytes for a unicode character), // then break the loop and leave these bytes in unprocessed buffer for them to get processed later with more bytes returned by stdout. if unprocessedBytesLen-i < utf8.UTFMax { runeReader.UnreadRune() break } // If invalid character is encountered beyond last 3 bytes of buffer, then the character at ith position is invalid utf8 character. // Add invalid byte at ith position to processedBuf in such case and return to client to handle display of invalid character. processedBuf.Write(unprocessedBytes[i : i+1]) } else { processedBuf.WriteRune(stdoutRune) } i += stdoutRuneLen } if p.dataChannel != nil { if err := p.dataChannel.SendStreamDataMessage(processedBuf.Bytes()); err != nil { return processedBuf, fmt.Errorf("unable to send stream data message: %s", err) } } // log.GetLogger().Println("data output: ", string(processedBuf.Bytes())) // return incomplete utf8 encoded unicode bytes to be processed with next batch of stdoutBytes unprocessedBuf.Reset() if i < unprocessedBytesLen { unprocessedBuf.Write(unprocessedBytes[i:unprocessedBytesLen]) } return unprocessedBuf, nil } func (p *ShellPlugin) InputStreamMessageHandler(streamDataMessage message.Message) error { if p.stdin == nil || p.stdout == nil { // This is to handle scenario when cli/console starts sending size data but pty has not been started yet // Since packets are rejected, cli/console will resend these packets until pty starts successfully in separate thread p.logger.Error("Pty unavailable. Reject incoming message packet") return nil } switch streamDataMessage.MessageType { case message.InputStreamDataMessage: // log.GetLogger().Traceln("Input message received: ", streamDataMessage.Payload) if err := p.onInputStreamData(streamDataMessage.Payload); err != nil { return err } case message.SetSizeDataMessage: var size SizeData if err := json.Unmarshal(streamDataMessage.Payload, &size); err != nil { p.logger.Errorf("Invalid size message: %s", err) return err } // log.GetLogger().Tracef("Resize data received: cols: %d, rows: %d", size.Cols, size.Rows) if err := p.SetSize(size.Cols, size.Rows); err != nil { p.logger.Errorf("Unable to set pty size: %s", err) return err } case message.StatusDataMessage: if len(streamDataMessage.Payload) > 0 { code, err := message.BytesToIntU(streamDataMessage.Payload[0:1]) if err == nil { switch code { case 7: // 设置agent的发送速率 speed, err := message.BytesToIntU(streamDataMessage.Payload[1:]) // speed 单位是 bps if speed == 0 { break } if err != nil { p.logger.Errorf("Invalid flowLimit: %s", err) return err } p.sendInterval = 1000 / (speed / 8 / sendPackageSize) p.logger.Infof("Set send speed, speed[%d]bps sendInterval[%d]ms\n", speed, p.sendInterval) case 5: p.logger.Info("Exit due to receiving a packet with a close status") p.exitFunc(errors.New(Ok)) } } else { p.logger.Errorf("Parse status code err: %s", err) } } case message.CloseDataChannel: p.logger.Info("Exit due to receiving CloseDataChannel packet") p.exitFunc(errors.New(Ok)) } return nil }