agent/session/plugin/client/client.go (669 lines of code) (raw):

package client import ( "bufio" "bytes" "encoding/binary" "errors" "fmt" "io" "net" "net/http" "os" "runtime" "runtime/debug" "sync" "time" "github.com/aliyun/aliyun_assist_client/agent/session/plugin/log" "github.com/aliyun/aliyun_assist_client/agent/session/plugin/message" "go.uber.org/atomic" "github.com/containerd/console" "github.com/creack/goselect" "github.com/gorilla/websocket" tsize "github.com/kopoli/go-terminal-size" "github.com/sirupsen/logrus" "golang.org/x/crypto/ssh/terminal" ) const ( committedSuicide = iota killed ) const ( sendPackageSize = 2048 // 发送的数据包大小上限,单位 B defaultSendSpeed = 200 // 默认的最大数据发送速率,单位 kbps defaultSendInterval = 1000 / (defaultSendSpeed * 1024 / 8 / sendPackageSize) // writeloop的循环间隔时间 单位ms minimumIdleTimeout = 60 ) // 状态码为5时可能的错误码 const ( EXIT = "Exit" AGENT_TIMEOUT = "AgentTimeout" INIT_CHANNEL_FAILED = "InitChannelFailed" OPEN_CHANNEL_FAILED = "OpenChannelFailed" SESSIONID_DUPLICATED = "SessionIdDuplicated" PROCESS_DATA_ERROR = "ProcessDataError" OPEN_PTY_FAILED = "OpenPtyFailed" FLOW_EXCEED_LIMIT = "FlowExceedLimit" ) type Client struct { Dialer *websocket.Dialer Conn *websocket.Conn URL string token string Connected bool Output io.Writer Input io.ReadCloser WriteMutex *sync.Mutex EscapeKeys []byte PortForward bool // true means the client is for portforward poison chan bool StreamDataSequenceNumber int64 rawmode bool //true means not use console mode verbosemode bool real_connected bool sendInterval int // ms, interval between writeLoop, for limit send speed // If the client is idle for more than idletimeout seconds, the connection // will be closed. // idleTimeout only works for session, the minimum is minimumIdleTimeout(60s). idleTimeout int32 // seond // lastDataTimestampOffset is the difference between the timestamp of the // last received or sent data packet and the timestamp of the Client startup. lastDataTimestampOffset atomic.Int32 // second // startTimestamp is the timestamp of the Client startup. startTimestamp time.Time } func NewClient(inputURL string, input io.ReadCloser, output io.Writer, portForward bool, token string, rawmode bool, verbosemode bool, idleTimeout int32) (*Client, error) { c := &Client{ Dialer: &websocket.Dialer{}, URL: inputURL, token: token, WriteMutex: &sync.Mutex{}, Output: output, // os.Stdout Input: input, // os.Stdin PortForward: portForward, StreamDataSequenceNumber: 0, rawmode: rawmode, real_connected: false, verbosemode: verbosemode, poison: make(chan bool), sendInterval: defaultSendInterval, startTimestamp: time.Now(), } if idleTimeout > 0 && idleTimeout < minimumIdleTimeout { idleTimeout = minimumIdleTimeout } c.idleTimeout = idleTimeout return c, nil } // Connect tries to dial a websocket server func (c *Client) Connect() error { // Open WebSocket connection logrus.Debugln("Connecting to websocket: ", c.URL) header := http.Header{} header.Add("x-acs-session-token", c.token) conn, _, err := c.Dialer.Dial(c.URL, header) if err != nil { return err } c.Conn = conn c.Connected = true // Initialize message types for gotty // go c.pingLoop() return nil } func (c *Client) pingLoop() { for { if c.Connected { logrus.Debugf("Sending ping") err := c.Conn.WriteMessage(websocket.PingMessage, []byte("keepalive")) if err != nil { logrus.Warnf("c.write: %v", err) } } time.Sleep(30 * time.Second) } } func (c *Client) Loop() error { if !c.Connected { err := c.Connect() if err != nil { return err } } if !c.rawmode { if runtime.GOOS == "darwin" { stdin := int(os.Stdin.Fd()) log.GetLogger().Infoln("under darwin") oldState, err := terminal.MakeRaw(stdin) if err != nil { log.GetLogger().Errorln(err) fmt.Printf("capture stdin failed %s\r\n", err) } defer func() { if e := recover(); e != nil { fmt.Fprintln(os.Stderr, e) fmt.Fprintln(os.Stderr, string(debug.Stack())) } terminal.Restore(stdin, oldState) if !c.PortForward { os.Exit(1) } }() } else { term, err := console.ConsoleFromFile(os.Stdout) if err != nil { log.GetLogger().Errorln(err) return fmt.Errorf("os.Stdout is not a valid terminal") } err = term.SetRaw() if err != nil { return fmt.Errorf("Error setting raw terminal: %v", err) } defer func() { if e := recover(); e != nil { fmt.Fprintln(os.Stderr, e) fmt.Fprintln(os.Stderr, string(debug.Stack())) } term.Reset() if !c.PortForward { os.Exit(1) } }() } } else { log.GetLogger().Infoln("under rawmode") defer func() { if e := recover(); e != nil { fmt.Fprintln(os.Stderr, e) fmt.Fprintln(os.Stderr, string(debug.Stack())) } if !c.PortForward { os.Exit(1) } }() } wg := &sync.WaitGroup{} wg.Add(1) go c.termsizeLoop(wg) if !c.rawmode { wg.Add(1) go c.writeLoop(wg) } else { wg.Add(1) go c.writeLoopRawMode(wg) } wg.Add(1) go c.readLoop(wg) /* Wait for all of the above goroutines to finish */ //wg.Wait() <-c.poison logrus.Debug("Client.Loop() exiting") return nil } func (c *Client) termsizeLoop(wg *sync.WaitGroup) int { defer wg.Done() fname := "termsizeLoop" for !c.real_connected { time.Sleep(time.Millisecond * 500) } width, height := -1, -1 // repeating this loop for every 500ms resizeTimer := time.NewTimer(time.Millisecond * 500) for { s, err := tsize.GetSize() if err != nil { log.GetLogger().Warning("get terminal size fail: ", err) return 0 } if s.Width != width || s.Height != height { width = s.Width height = s.Height buf := new(bytes.Buffer) binary.Write(buf, binary.LittleEndian, int16(height)) binary.Write(buf, binary.LittleEndian, int16(width)) b := buf.Bytes() if err = c.SendResizeDataMessage(b); err != nil { log.GetLogger().Warnf("ws.WriteMessage failed: %v", err) } } select { case <-c.poison: /* Somebody poisoned the well; die */ return die(fname, c.poison) case <-resizeTimer.C: resizeTimer.Reset(time.Millisecond * 500) } } } func bytesToIntU(b []byte) (int, error) { if len(b) == 3 { b = append([]byte{0}, b...) } bytesBuffer := bytes.NewBuffer(b) switch len(b) { case 1: var tmp uint8 err := binary.Read(bytesBuffer, binary.BigEndian, &tmp) return int(tmp), err case 2: var tmp uint16 err := binary.Read(bytesBuffer, binary.BigEndian, &tmp) return int(tmp), err case 4: var tmp uint32 err := binary.Read(bytesBuffer, binary.BigEndian, &tmp) return int(tmp), err default: return 0, fmt.Errorf("%s", "BytesToInt bytes lenth is invalid!") } } func (c *Client) ProcessStatusDataChannel(payload []byte) error { if c.verbosemode { log.GetLogger().Infoln("read status data: ", payload) } code, err := bytesToIntU(payload[0:1]) if err == nil { if code == 2 { //建立连接失败 log.GetLogger().Errorln("connect failed code 2") c.Output.Write(payload) return errors.New("Failed to connect. code 2") } else if code == 5 { //关闭连接 errorCode := string(payload[1:]) tipStr := errorCode log.GetLogger().Errorln("connect failed code 5") switch errorCode { case EXIT: tipStr = fmt.Sprint(EXIT, ": session closed.") break case AGENT_TIMEOUT: tipStr = fmt.Sprint(AGENT_TIMEOUT, ": session closed for agent timeout.") break case INIT_CHANNEL_FAILED: tipStr = fmt.Sprint(INIT_CHANNEL_FAILED, ": session closed for init channel failed") break case OPEN_CHANNEL_FAILED: tipStr = fmt.Sprint(OPEN_CHANNEL_FAILED, ": session closed for open channel failed.") break case SESSIONID_DUPLICATED: tipStr = fmt.Sprint(SESSIONID_DUPLICATED, ": session closed for sessionId is duplicated.") break case PROCESS_DATA_ERROR: tipStr = fmt.Sprint(PROCESS_DATA_ERROR, ": session closed for error while process data.") break case OPEN_PTY_FAILED: tipStr = fmt.Sprint(OPEN_PTY_FAILED, ": session closed for open pty failed.") break case FLOW_EXCEED_LIMIT: tipStr = fmt.Sprint(FLOW_EXCEED_LIMIT, ": session closed for the flow exceeds limit.") break } log.GetLogger().Errorln("connect failed code 5:", tipStr) fmt.Println(tipStr) return errors.New("Connection closed. code 5") } else if code == 3 { } else if code == 7 { // 设置client的发送速率 speed, err := bytesToIntU(payload[1:]) // speed 单位是 bps if speed != 0 { if err != nil { return err } c.sendInterval = 1000 / (speed / 8 / sendPackageSize) log.GetLogger().Infof("Set send speed, speed[%d]bps sendInterval[%d]ms\n", speed, c.sendInterval) } } } return nil } func (c *Client) readLoop(wg *sync.WaitGroup) int { defer wg.Done() fname := "readLoop" type MessageNonBlocking struct { Msg message.Message Err error } msgChan := make(chan MessageNonBlocking) for { go func() { defer func() { if r := recover(); r != nil { logrus.Debug("readLoop returned so msgChan closed", r) } }() _, data, err := c.Conn.ReadMessage() if c.verbosemode { log.GetLogger().Infoln("read msg: ", string(data)) } streamDataMessage := message.Message{} if err == nil { if err = streamDataMessage.Deserialize(data); err != nil { log.GetLogger().Errorf("Cannot deserialize raw message, err: %v.", err) } } else { log.GetLogger().Errorln("read msg err") openPoison(fname, c.poison) } if c.verbosemode { log.GetLogger().Infoln("read msg num : ", streamDataMessage.SequenceNumber) } msgChan <- MessageNonBlocking{Msg: streamDataMessage, Err: err} // time.Sleep(time.Second * 1) // msgChan <- MessageNonBlocking{Data: []byte("c"), Err: nil} }() select { case <-c.poison: close(msgChan) return die(fname, c.poison) case msg := <-msgChan: if msg.Err != nil { log.GetLogger().Errorln("read msg err", msg.Err) if _, ok := msg.Err.(*websocket.CloseError); !ok { log.GetLogger().Warnf("c.Conn.ReadMessage: %v", msg.Err) } return openPoison(fname, c.poison) } if msg.Msg.Validate() != nil { log.GetLogger().Errorln("An error has occured, msg is invalid") return openPoison(fname, c.poison) } switch msg.Msg.MessageType { case message.OutputStreamDataMessage: // data c.real_connected = true c.Output.Write(msg.Msg.Payload) c.lastDataTimestampOffset.Store(int32(time.Since(c.startTimestamp).Seconds())) break case message.StatusDataChannel: // data if c.ProcessStatusDataChannel(msg.Msg.Payload) != nil { return openPoison(fname, c.poison) } break default: // logrus.Warnf("Unhandled protocol message") } } } return 0 } type exposeFd interface { Fd() uintptr } func (c *Client) writeLoopRawMode(wg *sync.WaitGroup) int { defer wg.Done() fname := "writeLoop" buff := make([]byte, sendPackageSize) br := bufio.NewReader(c.Input) if c.PortForward { // wait agent build local connect time.Sleep(time.Duration(3) * time.Second) c.real_connected = true if c.verbosemode { log.GetLogger().Info("set real_connected true") } } var resend_buff []byte for { time.Sleep(time.Duration(c.sendInterval) * time.Millisecond) // if resend_buff not empty, set read timeout avoid blocking forever if c.PortForward && len(resend_buff) > 0 { c.Input.(net.Conn).SetReadDeadline(time.Now().Add(1 * time.Second)) } size, err := br.Read(buff) if c.PortForward { // cancel the read timeout c.Input.(net.Conn).SetReadDeadline(time.Time{}) } if err != nil { if netErr, ok := err.(net.Error); ok && netErr.Timeout() { if c.verbosemode { log.GetLogger().Warnln("read from local conn timeout:", err) } } else { log.GetLogger().Errorf("get raw input failed: %v", err) // tell agent to close session log.GetLogger().Infoln("local conn closed, send CloseMessage") if err = c.SendCloseMessage(); err != nil { log.GetLogger().Errorf("SendCloseMessage err: %v", err) } return openPoison(fname, c.poison) } } if size == 0 && len(resend_buff) == 0 { continue } data := buff[:size] if c.real_connected == true { if len(resend_buff) > 0 { time.Sleep(time.Duration(100) * time.Millisecond) c.SendStreamDataMessage(resend_buff) log.GetLogger().Infoln("agent ready resend user input:", string(resend_buff), len(resend_buff)) resend_buff = nil } err = c.SendStreamDataMessage(data) if err != nil { return openPoison(fname, c.poison) } if c.verbosemode { log.GetLogger().Infoln("send user input:", string(data), size) } } else { if len(resend_buff) == 0 { resend_buff = make([]byte, size) copy(resend_buff, buff[:size]) log.GetLogger().Infoln("store user input:", string(data), size) } } } return 0 } func (c *Client) writeLoop(wg *sync.WaitGroup) int { defer wg.Done() fname := "writeLoop" buff := make([]byte, 2048) rdfs := &goselect.FDSet{} reader := io.ReadCloser(c.Input) pr := NewEscapeProxy(reader, c.EscapeKeys) defer reader.Close() if c.idleTimeout > 0 { log.GetLogger().Infof("Idle timeout %d seconds.", c.idleTimeout) go func() { // Init timer as 1 seconds, timer := time.NewTimer(time.Second) defer timer.Stop() for { select { case <-c.poison: return case <-timer.C: elapsedTime := int32(time.Since(c.startTimestamp).Seconds()) - c.lastDataTimestampOffset.Load() if elapsedTime >= c.idleTimeout { // Idle for too long log.GetLogger().Infoln("Idle for too long, close client.") c.SendCloseMessage() openPoison(fname, c.poison) } timer.Reset(time.Duration(c.idleTimeout - elapsedTime) * time.Second) } } }() go func() { // minimumIdleTimeout is 60s, so the period of sending keep-alive // package is set to 60s, it does not need to be set too small. // But it cannot exceed 180 seconds, because the Agent will // disconnect if agent does not receive a data packet within 180s. timer := time.NewTicker(time.Minute) defer timer.Stop() for { select { case <-c.poison: return case <-timer.C: if err := c.SendKeepAliveDataMessage(); err != nil { log.GetLogger().Error("Send keep alive package failed: ", err) } } } }() } for { select { case <-c.poison: return die(fname, c.poison) default: } rdfs.Zero() rdfs.Set(reader.(exposeFd).Fd()) err := goselect.Select(1, rdfs, nil, nil, 50*time.Millisecond) if err != nil { // log.GetLogger().Errorf("get raw input failed: %v", err) continue // return openPoison(fname, c.poison) } if rdfs.IsSet(reader.(exposeFd).Fd()) { size, err := pr.Read(buff) if err != nil { log.GetLogger().Infoln("err in input empty") if err == io.EOF { log.GetLogger().Infoln("EOF in input empty") // Send EOF to GoTTY // Send 'Input' marker, as defined in GoTTY::client_context.go, // followed by EOT (a translation of Ctrl-D for terminals) err = c.SendStreamDataMessage((append([]byte{}, byte(4)))) return openPoison(fname, c.poison) continue } else { log.GetLogger().Errorln("err in input empty", err) return openPoison(fname, c.poison) } } if size <= 0 { log.GetLogger().Infoln("user input empty") continue } data := buff[:size] if c.verbosemode { log.GetLogger().Infoln("begin send user input:", string(data), size) } err = c.SendStreamDataMessage(data) if err != nil { return openPoison(fname, c.poison) } } } } func (c *Client) SendStreamDataMessage(inputData []byte) (err error) { if len(inputData) == 0 { log.GetLogger().Debugf("Ignoring empty stream data payload.") return nil } agentMessage := &message.Message{ MessageType: message.InputStreamDataMessage, SchemaVersion: "1.01", CreatedDate: uint64(time.Now().UnixNano() / 1000000), SequenceNumber: c.StreamDataSequenceNumber, PayloadLength: uint32(len(inputData)), Payload: inputData, } if c.verbosemode { log.GetLogger().Infoln("SendStreamDataMessage num: ", c.StreamDataSequenceNumber) } msg, err := agentMessage.Serialize() if err != nil { return fmt.Errorf("cannot serialize StreamData message %v", agentMessage) } if err = c.sendMessage(msg, websocket.BinaryMessage); err != nil { log.GetLogger().Errorf("Error sending stream data message %v", err) log.GetLogger().Infoln("disconnect, plugin exit") // os.Exit(1) c.Connected = false return err } c.lastDataTimestampOffset.Store(int32(time.Since(c.startTimestamp).Seconds())) if c.verbosemode { log.GetLogger().Println("SendStreamDataMessage:", msg) } c.StreamDataSequenceNumber = c.StreamDataSequenceNumber + 1 return nil } func (c *Client) SendKeepAliveDataMessage() (err error) { agentMessage := &message.Message{ MessageType: message.InputStreamDataMessage, SchemaVersion: "1.01", CreatedDate: uint64(time.Now().UnixNano() / 1000000), SequenceNumber: c.StreamDataSequenceNumber, PayloadLength: 0, Payload: []byte{}, } if c.verbosemode { log.GetLogger().Infoln("SendKeepAliveDataMessage num: ", c.StreamDataSequenceNumber) } msg, err := agentMessage.Serialize() if err != nil { return fmt.Errorf("cannot serialize StreamData message %v, %v", agentMessage, err) } if err = c.sendMessage(msg, websocket.BinaryMessage); err != nil { log.GetLogger().Errorf("Error sending keep alive message %v", err) return err } if c.verbosemode { log.GetLogger().Println("SendKeepAliveDataMessage:", msg) } c.StreamDataSequenceNumber = c.StreamDataSequenceNumber + 1 return nil } func (c *Client) SendCloseMessage() (err error) { inputData := []byte("1") agentMessage := &message.Message{ MessageType: message.CloseDataChannel, SchemaVersion: "1.01", CreatedDate: uint64(time.Now().UnixNano() / 1000000), SequenceNumber: c.StreamDataSequenceNumber, PayloadLength: uint32(len(inputData)), Payload: inputData, } if c.verbosemode { log.GetLogger().Infoln("SendCloseMessage num: ", c.StreamDataSequenceNumber) } msg, err := agentMessage.Serialize() if err != nil { return fmt.Errorf("cannot serialize StreamData message %v", agentMessage) } if err = c.sendMessage(msg, websocket.BinaryMessage); err != nil { log.GetLogger().Errorf("Error sending stream data message %v", err) log.GetLogger().Infoln("disconnect, plugin exit") // os.Exit(1) c.Connected = false return err } log.GetLogger().Infoln("SendCloseMessage") c.StreamDataSequenceNumber = c.StreamDataSequenceNumber + 1 return nil } func (c *Client) SendResizeDataMessage(inputData []byte) (err error) { if len(inputData) == 0 { log.GetLogger().Debugf("Ignoring empty stream data payload.") return nil } agentMessage := &message.Message{ MessageType: message.SetSizeDataMessage, SchemaVersion: "1.01", CreatedDate: uint64(time.Now().UnixNano() / 1000000), SequenceNumber: c.StreamDataSequenceNumber, PayloadLength: uint32(len(inputData)), Payload: inputData, } msg, err := agentMessage.Serialize() if err != nil { log.GetLogger().Errorf("cannot serialize StreamData message %v", agentMessage) return fmt.Errorf("cannot serialize StreamData message %v", agentMessage) } if err = c.sendMessage(msg, websocket.BinaryMessage); err != nil { log.GetLogger().Errorf("Error sending stream data message %v", err) return err } c.lastDataTimestampOffset.Store(int32(time.Since(c.startTimestamp).Seconds())) c.StreamDataSequenceNumber = c.StreamDataSequenceNumber + 1 return nil } func (c *Client) sendMessage(input []byte, inputType int) error { defer func() { if msg := recover(); msg != nil { log.GetLogger().Errorf("WebsocketChannel run panic: %v", msg) log.GetLogger().Errorf("%s: %s", msg, debug.Stack()) } }() if len(input) < 1 { log.GetLogger().Errorln("Can't send message: Empty input.") return errors.New("Can't send message: Empty input.") } c.WriteMutex.Lock() err := c.Conn.WriteMessage(inputType, input) if c.verbosemode { log.GetLogger().Infoln("begin send msg: ", string(input)) } if err != nil { log.GetLogger().Errorf("send messagefaile, %v", err) } c.WriteMutex.Unlock() return err } func openPoison(fname string, poison chan bool) int { logrus.Debug(fname + " suicide") /* * The close() may raise panic if multiple goroutines commit suicide at the * same time. Prevent that panic from bubbling up. */ defer func() { if r := recover(); r != nil { logrus.Debug("Prevented panic() of simultaneous suicides", r) } }() /* Signal others to die */ close(poison) return committedSuicide } func die(fname string, poison chan bool) int { logrus.Debug(fname + " died") wasOpen := <-poison if wasOpen { logrus.Error("ERROR: The channel was open when it wasn't supposed to be") } return killed }