agent/session/port/port.go
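// Package port implements the port-forwarding session plugin, which relays
// data between a session data channel and a local TCP port.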
package port
import (
"context"
"errors"
"fmt"
"net"
"strconv"
"time"
"github.com/aliyun/aliyun_assist_client/agent/log"
"github.com/aliyun/aliyun_assist_client/agent/session/channel"
"github.com/aliyun/aliyun_assist_client/agent/session/message"
"github.com/aliyun/aliyun_assist_client/agent/session/shell"
"github.com/aliyun/aliyun_assist_client/agent/util"
"github.com/aliyun/aliyun_assist_client/thirdparty/sirupsen/logrus"
"go.uber.org/atomic"
)
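// Result codes reported when the port plugin exits.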
const (
Ok = "Ok"
Open_port_failed = "Open_port_failed"
Read_port_failed = "Read_port_failed"
IO_socket_error = "IO_port_failed"
Unknown_error = "Unknown_error"
)
const (
sendPackageSize = 2048 // maximum payload size per send, in bytes
defaultSendSpeed = 200 // default maximum data send rate, in kbps
defaultSendInterval = 1000 / (defaultSendSpeed * 1024 / 8 / sendPackageSize) // interval between writePump iterations, in ms
waitLocalConnTimeoutSecond = 10
)
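// PortPlugin relays data between a session data channel and a TCP connection
// to targetHost:portNumber, pacing outgoing data by sendInterval.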
type PortPlugin struct {
id string
targetHost string
portNumber int
dataChannel channel.ISessionChannel
conn net.Conn
connReady chan struct{}
sendInterval int
needReconnect atomic.Bool
reconnectSign chan struct{}
exitCtx context.Context
exitFunc context.CancelCauseFunc
logger logrus.FieldLogger
}
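// NewPortPlugin creates a port plugin for the given session id and target
// address. An empty targetHost defaults to "localhost"; a positive flowLimit
// (in bps) overrides the default send interval.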
func NewPortPlugin(id string, targetHost string, portNumber int, flowLimit int) *PortPlugin {
if targetHost == "" {
targetHost = "localhost"
}
plugin := &PortPlugin{
id: id,
targetHost: targetHost,
portNumber: portNumber,
sendInterval: defaultSendInterval,
logger: log.GetLogger().WithFields(logrus.Fields{
"sessionType": "portforward",
"sessionId": id,
}),
}
plugin.connReady = make(chan struct{})
plugin.reconnectSign = make(chan struct{})
// flowLimit is in bps; fall back to the default when it is unset or too small
// for the divisor below (avoids a division by zero).
if flowLimit/8/sendPackageSize > 0 {
plugin.sendInterval = 1000 / (flowLimit / 8 / sendPackageSize)
} else {
flowLimit = defaultSendSpeed * 1024
}
plugin.logger.Infof("Init send speed, channelId[%s] speed[%d]bps sendInterval[%d]ms\n", id, flowLimit, plugin.sendInterval)
return plugin
}
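// Stop closes the local TCP connection if it has been established.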
func (p *PortPlugin) Stop() {
if p.conn == nil {
return
}
err := p.conn.Close()
p.logger.WithError(err).Info("Close local connection")
}
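// Execute dials the target host port, starts the write pump that relays data
// from the port to the data channel, and blocks until the session is cancelled
// or an error occurs. The returned errorCode describes why the plugin exited.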
func (p *PortPlugin) Execute(dataChannel channel.ISessionChannel, cancelFlag util.CancelFlag) (errorCode string, pluginErr error) {
p.dataChannel = dataChannel
// p.exitCtx and p.exitFunc are initialized in Execute() instead of NewPortPlugin() to ensure
// that p.exitCtx can be released eventually.
// Although InputStreamMessageHandler() may be called before Execute(), it will wait until
// p.conn is successfully created before starting to process data (possibly calling p.exitFunc),
// so p.exitFunc will not be called before it is initialized.
p.exitCtx, p.exitFunc = context.WithCancelCause(context.Background())
defer func() {
p.logger.Infoln("stop in run PortPlugin")
p.Stop()
if err := recover(); err != nil {
p.logger.Errorf("Error occurred while executing port plugin %s: \n%v", p.id, err)
// Panic in session port plugin SHOULD NOT disturb the whole agent
// process
errorCode = Unknown_error
if v, ok := err.(error); ok {
pluginErr = v
} else {
pluginErr = fmt.Errorf("%v", err)
}
}
}()
targetHostPort := net.JoinHostPort(p.targetHost, strconv.Itoa(p.portNumber))
p.logger.Infoln("start port, dial tcp connection to ", targetHostPort)
if p.conn, pluginErr = net.DialTimeout("tcp", targetHostPort, time.Second*waitLocalConnTimeoutSecond); pluginErr != nil {
errorString := fmt.Sprintf("Unable to start port: %s", pluginErr)
p.logger.Errorln(errorString)
errorCode = Open_port_failed
return
}
close(p.connReady)
p.logger.Infoln("start port success")
go func() {
select {
case <-cancelFlag.C():
cancelState := cancelFlag.State()
if cancelState == util.Canceled {
p.exitFunc(errors.New(shell.Timeout))
} else {
p.exitFunc(errors.New(shell.Notified))
}
p.logger.Debugf("Cancel flag set to %v in session", cancelState)
case <-p.exitCtx.Done():
cancelFlag.Set(util.ShutDown)
}
}()
go func() {
p.writePump()
}()
p.logger.Infof("Plugin %s started", p.id)
<-p.exitCtx.Done()
errorCode = context.Cause(p.exitCtx).Error()
p.logger.Infoln("Plugin done", p.id, errorCode)
return
}
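// writePump reads from the local TCP connection and forwards the data to the
// session data channel, sleeping sendInterval milliseconds between iterations.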
func (p *PortPlugin) writePump() {
defer func() {
p.logger.Info("writePump done")
if err := recover(); err != nil {
p.logger.Infoln("WritePump thread crashed with message ", err)
fmt.Println("WritePump thread crashed with message: \n", err)
}
}()
packet := make([]byte, sendPackageSize)
for {
select {
case <-p.exitCtx.Done():
return
default:
if p.dataChannel.IsActive() {
numBytes, err := p.conn.Read(packet)
if err != nil {
if connected := p.onError(err); connected {
p.logger.Infoln("Reconnection to port is successful, resume reading from port.")
continue
}
p.logger.Infof("Unable to read port: %v", err)
return
}
if util.IsVerboseMode() {
p.logger.Infoln("read data:", string(packet[:numBytes]))
}
if err = p.dataChannel.SendStreamDataMessage(packet[:numBytes]); err != nil {
p.logger.Errorf("Unable to send stream data message: %v", err)
p.exitFunc(errors.New(IO_socket_error))
return
}
} else {
p.logger.Infoln("PortPlugin:writePump stream is closed")
p.exitFunc(errors.New(IO_socket_error))
return
}
// Wait for TCP to process more data
time.Sleep(time.Duration(p.sendInterval) * time.Millisecond)
}
}
}
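// onError handles a read error from the local connection: it closes the
// connection, marks it for reconnection, and waits until the reconnection
// succeeds or the plugin exits. It reports whether reading can resume.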
func (p *PortPlugin) onError(theErr error) (reconnected bool) {
p.logger.Infoln("Encountered reconnect while reading from port: ", theErr)
select {
case <-p.exitCtx.Done():
reconnected = false
return
default:
}
p.Stop()
p.needReconnect.Store(true)
p.logger.Debugf("Waiting for reconnection to port!!")
select {
case <-p.reconnectSign:
reconnected = true
case <-p.exitCtx.Done():
reconnected = false
}
return
}
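// InputStreamMessageHandler processes messages from the data channel: it
// writes input stream data to the local port (reconnecting first if needed)
// and handles status messages such as send-rate updates and close requests.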
func (p *PortPlugin) InputStreamMessageHandler(streamDataMessage message.Message) error {
if p.conn == nil {
p.logger.Infof("InputStreamMessageHandler: connect not ready, wait %d seconds...", waitLocalConnTimeoutSecond)
timer := time.NewTimer(time.Duration(waitLocalConnTimeoutSecond) * time.Second)
select {
case <-p.connReady:
p.logger.Infoln("InputStreamMessageHandler: connect is ready")
case <-timer.C:
p.logger.Infof("InputStreamMessageHandler: connect still not ready after %d seconds", waitLocalConnTimeoutSecond)
return fmt.Errorf("connection with target host port not ready")
}
}
switch streamDataMessage.MessageType {
case message.InputStreamDataMessage:
// Reconnect only when receive data message
if p.needReconnect.CompareAndSwap(true, false) {
targetHostPort := net.JoinHostPort(p.targetHost, strconv.Itoa(p.portNumber))
p.logger.Infof("InputStreamMessageHandler:Reconnect to %s", targetHostPort)
var err error
p.conn, err = net.Dial("tcp", targetHostPort)
if err != nil {
p.logger.WithError(err).Error("Reconnect failed")
p.exitFunc(errors.New(Read_port_failed))
return err
}
p.logger.Info("Reconnect succeeded")
// Signal writePump (blocked in onError) that it can resume reading.
p.reconnectSign <- struct{}{}
}
if _, err := p.conn.Write(streamDataMessage.Payload); err != nil {
p.logger.Errorf("Unable to write to port, err: %v.", err)
return err
}
if util.IsVerboseMode() {
p.logger.Infoln("write data:", string(streamDataMessage.Payload))
}
case message.StatusDataMessage:
p.logger.Info("message type: ", streamDataMessage.MessageType)
if len(streamDataMessage.Payload) > 0 {
code, err := message.BytesToIntU(streamDataMessage.Payload[0:1])
p.logger.WithError(err).Info("message status code: ", code)
if err == nil {
switch code {
case 7: // set the agent's send rate
speed, err := message.BytesToIntU(streamDataMessage.Payload[1:]) // speed is in bps
if err != nil {
p.logger.Errorf("Invalid flowLimit: %s", err)
return err
}
// Ignore rates below sendPackageSize*8 bps so the divisor cannot be zero.
if speed/8/sendPackageSize == 0 {
break
}
p.sendInterval = 1000 / (speed / 8 / sendPackageSize)
p.logger.Infof("Set send speed, channelId[%s] speed[%d]bps sendInterval[%d]ms\n", p.id, speed, p.sendInterval)
case 5:
p.logger.Info("Exit due to receiving a packet with a close status")
p.exitFunc(errors.New(Ok))
}
} else {
p.logger.Errorf("Parse status code err: %s", err)
}
}
case message.CloseDataChannel:
p.logger.Info("Exit due to receiving CloseDataChannel packet")
p.exitFunc(errors.New(Ok))
}
return nil
}