agent/channel/channel_websocket.go (268 lines of code) (raw):
package channel
import (
"crypto/tls"
"crypto/x509"
"errors"
"fmt"
"net/http"
"runtime/debug"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/aliyun/aliyun_assist_client/agent/clientreport"
"github.com/aliyun/aliyun_assist_client/agent/log"
"github.com/aliyun/aliyun_assist_client/agent/metrics"
"github.com/aliyun/aliyun_assist_client/agent/util"
_ "github.com/aliyun/aliyun_assist_client/common/apiserver"
"github.com/aliyun/aliyun_assist_client/common/httpbase"
"github.com/aliyun/aliyun_assist_client/common/requester"
)
const (
// WEBSOCKET_SERVER is the URL path appended to "wss://<server host>" when
// StartChannel dials the websocket notify endpoint.
WEBSOCKET_SERVER = "/luban/notify_server"
// MAX_RETRY_COUNT is the number of consecutive receive errors after which
// the receive loop closes the connection and switches to another channel.
MAX_RETRY_COUNT = 5
)
var (
// wssCoolDownCount is the number of consecutive failed dials after which
// StartChannel starts rejecting further attempts for wssCoolDownTime seconds.
wssCoolDownCount = 1 // limit of consecutive failed connections
// wssCoolDownTime is the length of that cool-down period, in seconds.
wssCoolDownTime = 60 // seconds
)
// WebSocketChannel is the channel implementation that receives messages
// over a websocket connection to the server's notify endpoint.
type WebSocketChannel struct {
// Channel is the embedded base channel; this file uses its CallBack,
// ChannelType and Working members.
*Channel
// wskConn is the connection established by the last successful StartChannel.
wskConn *websocket.Conn
// lock serializes StartChannel, StopChannel and ResetFailedCount, and is
// also taken by the receive loop when it gives up on the connection.
lock sync.Mutex
// writeLock serializes WriteMessage calls on wskConn (message replies and pings).
writeLock sync.Mutex
// consecutiveConnectFailed counts dial failures since the last successful
// dial or reset.
consecutiveConnectFailed int
// calmDownUntil is the time before which StartChannel refuses to dial once
// consecutiveConnectFailed has reached wssCoolDownCount.
calmDownUntil time.Time
}
// IsSupported reports whether the websocket channel can be used, which is the
// case exactly when util.GetServerHost returns a non-empty host. When no host
// is available, a channel-fail metrics event is reported and an error is
// logged before returning false.
func (c *WebSocketChannel) IsSupported() bool {
	if util.GetServerHost() != "" {
		return true
	}

	const reason = "websocket channel not supported"
	metrics.GetChannelFailEvent(
		metrics.EVENT_SUBCATEGORY_CHANNEL_WS,
		"errormsg", reason,
		"type", ChannelTypeStr(c.ChannelType),
	).ReportEvent()
	log.GetLogger().Error(reason)
	return false
}
// StartChannel dials the websocket notify server and, on success, starts the
// keepalive pings and the background receive loop.
//
// The call is serialized by c.lock. Once wssCoolDownCount consecutive dials
// have failed, the call is rejected immediately until calmDownUntil has
// passed. A dial that fails with a TLS certificate verification error is
// retried with reloaded root certificates.
//
// On success the connection is stored in c.wskConn and Working is set. The
// receive loop passes every text or binary message to c.CallBack and writes
// back any non-empty reply. After MAX_RETRY_COUNT consecutive read errors it
// closes the connection, clears Working and triggers SwitchChannel.
//
// It returns nil on success, and an error when the channel is cooling down,
// when no server host is available, or when the dial failed. Every failure
// except the cool-down rejection is reported as a channel-fail metrics event.
func (c *WebSocketChannel) StartChannel() error {
	c.lock.Lock()
	defer c.lock.Unlock()

	// errmsg stays empty unless this attempt fails; the deferred function
	// turns a non-empty value into exactly one channel-fail event.
	// NOTE(review): the other events in this file use the key "errormsg";
	// "errmsg" is kept here unchanged so reported data does not change —
	// confirm which key is intended.
	errmsg := ""
	defer func() {
		if len(errmsg) > 0 {
			metrics.GetChannelFailEvent(
				metrics.EVENT_SUBCATEGORY_CHANNEL_WS,
				"errmsg", errmsg,
				"type", ChannelTypeStr(c.ChannelType),
			).ReportEvent()
		}
	}()

	if c.consecutiveConnectFailed >= wssCoolDownCount {
		if time.Now().Before(c.calmDownUntil) {
			return errors.New("ws channel is calming down")
		}
		c.consecutiveConnectFailed = 0
		log.GetLogger().Info("ws channel is not calm anymore")
	}

	host := util.GetServerHost()
	if host == "" {
		errmsg = "No available host"
		return errors.New("No available host")
	}
	url := "wss://" + host + WEBSOCKET_SERVER
	logger := log.GetLogger().WithField("url", url)

	header := http.Header{
		httpbase.UserAgentHeader: []string{httpbase.UserAgentValue},
	}
	if extraHeaders, err := requester.GetExtraHTTPHeaders(logger); extraHeaders != nil {
		for k, v := range extraHeaders {
			header.Add(k, v)
		}
	} else if err != nil {
		logger.WithError(err).Error("Failed to construct extra HTTP headers")
	}

	dialer := &websocket.Dialer{
		Proxy:            requester.GetProxyFunc(logger),
		HandshakeTimeout: 45 * time.Second,
		TLSClientConfig: &tls.Config{
			RootCAs: requester.GetRootCAs(logger),
		},
	}

	conn, _, dialErr := dialer.Dial(url, header)
	if dialErr != nil {
		var certificateErr *tls.CertificateVerificationError
		if errors.As(dialErr, &certificateErr) {
			logger.WithError(dialErr).Error("certificate error, reload certificate and retry")
			// Presumably the callback is offered candidate certificate pools
			// and its return value says whether to keep trying (true after a
			// failed dial, false once a dial succeeded) — confirm against
			// requester.AccumulateRootCAs.
			//
			// errmsg is deliberately NOT set here: if a later candidate
			// succeeds the start must not be reported as a failure, and if
			// all candidates fail the block below builds the message from
			// the final dialErr.
			requester.AccumulateRootCAs(logger)(func(certPool *x509.CertPool) bool {
				dialer.TLSClientConfig.RootCAs = certPool
				if conn, _, dialErr = dialer.Dial(url, header); dialErr != nil {
					return true
				}
				requester.UpdateRootCAs(logger, certPool)
				logger.Info("certificate updated")
				return false
			})
		}
	}

	if dialErr != nil {
		c.consecutiveConnectFailed += 1
		if c.consecutiveConnectFailed >= wssCoolDownCount {
			c.calmDownUntil = time.Now().Add(time.Second * time.Duration(wssCoolDownTime))
			errmsg = fmt.Sprintf("dial ws channel errror:%s, url=%s, wss dial failed %d times "+
				"consecutivly, need calm down %d second",
				dialErr.Error(), url, c.consecutiveConnectFailed, wssCoolDownTime)
		} else {
			errmsg = fmt.Sprintf("dial ws channel errror:%s, url=%s", dialErr.Error(), url)
		}
		log.GetLogger().Errorln(dialErr)
		return dialErr
	}

	c.consecutiveConnectFailed = 0
	c.wskConn = conn
	logger.Infoln("Start websocket channel ok! url:", url)
	c.Working.Set()
	c.StartPings(time.Second * 60)

	// Receive loop: runs until Working is cleared or the retry limit is hit.
	go func() {
		defer func() {
			if msg := recover(); msg != nil {
				logger.Errorf("WebsocketChannel run panic: %v", msg)
				logger.Errorf("%s: %s", msg, debug.Stack())
			}
		}()

		retryCount := 0
		for {
			if !c.Working.IsSet() {
				logger.Infoln("websocket channel is closed")
				break
			}

			messageType, message, err := c.wskConn.ReadMessage()
			if err != nil {
				time.Sleep(time.Duration(1) * time.Second)
				retryCount++
				if retryCount >= MAX_RETRY_COUNT {
					// The deferred unlock fires when the goroutine returns,
					// which happens right after the break below.
					c.lock.Lock()
					defer c.lock.Unlock()
					c.wskConn.Close()
					c.Working.Clear()
					logger.Errorf("Reach the retry limit for receive messages. Error: %v", err.Error())
					report := clientreport.ClientReport{
						ReportType: "switch_channel_in_wsk",
						// Plain concatenation: the error text must not be
						// interpreted as a format string.
						Info: "start:" + err.Error(),
					}
					clientreport.SendReport(report)
					go c.SwitchChannel()
					break
				}
				logger.Errorf(
					"An error happened when receiving the message. Retried times: %d, MessageType: %v, Error: %s",
					retryCount,
					messageType,
					err.Error())
			} else if messageType != websocket.TextMessage && messageType != websocket.BinaryMessage {
				logger.Errorf("Invalid message type %d. ", messageType)
			} else {
				logger.Infof("wsk recv: %s", string(message))
				content := c.CallBack(string(message), ChannelWebsocketType)
				if content != "" {
					// writeLock keeps the reply from interleaving with pings.
					c.writeLock.Lock()
					writeErr := c.wskConn.WriteMessage(websocket.TextMessage, []byte(content))
					c.writeLock.Unlock()
					if writeErr != nil {
						metrics.GetChannelFailEvent(
							metrics.EVENT_SUBCATEGORY_CHANNEL_WS,
							"errormsg", fmt.Sprintf("websocket writing err:%s, content=%s", writeErr.Error(), content),
							"type", ChannelTypeStr(c.ChannelType),
						).ReportEvent()
					}
				}
				retryCount = 0
			}
		}
	}()
	return nil
}
// SwitchChannel asks the channel manager to select another available channel
// after the websocket receive loop has given up on its connection.
//
// It waits one second, then makes up to five selection attempts with a
// five-second pause after each failed attempt. It returns nil as soon as an
// attempt succeeds. If all attempts fail, it reports the failure both as a
// channel-switch metrics event and as a client report, and returns an error.
func (c *WebSocketChannel) SwitchChannel() error {
	const (
		reportType = "switch_channel_in_wsk"
		failInfo   = "fail: no available channel"
	)

	time.Sleep(time.Duration(1) * time.Second)
	for i := 0; i < 5; i++ {
		if err := G_ChannelMgr.SelectAvailableChannelAndReport(ChannelNone, reportType, true); err == nil {
			return nil
		}
		time.Sleep(time.Duration(5) * time.Second)
	}

	metrics.GetChannelSwitchEvent(
		"type", ChannelTypeStr(G_ChannelMgr.GetCurrentChannelType()),
		"reportType", reportType,
		"info", failInfo,
	).ReportEvent()
	report := clientreport.ClientReport{
		ReportType: reportType,
		Info:       failInfo,
	}
	clientreport.SendReport(report)
	return errors.New("no available channel")
}
// StopChannel closes the websocket connection if the channel is currently
// working; otherwise it does nothing. Working is cleared before the
// connection is closed. A close error is reported as a channel-fail metrics
// event and logged, but not returned: the method always returns nil.
func (c *WebSocketChannel) StopChannel() error {
	c.lock.Lock()
	defer c.lock.Unlock()

	if !c.Working.IsSet() {
		return nil
	}

	c.Working.Clear()
	log.GetLogger().Println("close websocket channel")
	if closeErr := c.wskConn.Close(); closeErr != nil {
		metrics.GetChannelFailEvent(
			metrics.EVENT_SUBCATEGORY_CHANNEL_WS,
			"errormsg", fmt.Sprintf("close websocket channel error:%s", closeErr),
			"type", ChannelTypeStr(c.ChannelType),
		).ReportEvent()
		log.GetLogger().Println("close websocket channel error:", closeErr)
	}
	return nil
}
// StartPings launches a goroutine that sends a websocket ping with the
// payload "keepalive" every pingInterval for as long as the channel is
// working. The goroutine exits when Working is no longer set at the start of
// an iteration, or when a ping fails to send; a send failure is reported as a
// channel-fail metrics event and logged.
func (c *WebSocketChannel) StartPings(pingInterval time.Duration) {
	keepalive := func() {
		for c.Working.IsSet() {
			log.GetLogger().Infoln("WebsocketChannel: ping...")

			// writeLock keeps the ping from interleaving with message replies.
			c.writeLock.Lock()
			pingErr := c.wskConn.WriteMessage(websocket.PingMessage, []byte("keepalive"))
			c.writeLock.Unlock()

			if pingErr != nil {
				metrics.GetChannelFailEvent(
					metrics.EVENT_SUBCATEGORY_CHANNEL_WS,
					"errormsg", fmt.Sprintf("Error while sending websocket ping: %s", pingErr.Error()),
					"type", ChannelTypeStr(c.ChannelType),
				).ReportEvent()
				log.GetLogger().Errorf("Error while sending websocket ping: %v", pingErr)
				return
			}
			time.Sleep(pingInterval)
		}
	}
	go keepalive()
}
// ResetFailedCount sets the consecutive dial-failure counter back to zero
// while holding c.lock, so the next StartChannel call is not rejected by the
// cool-down check.
func (c *WebSocketChannel) ResetFailedCount() {
	c.lock.Lock()
	defer c.lock.Unlock()

	c.consecutiveConnectFailed = 0
	log.GetLogger().Info("ws channel: reset fail count")
}
// NewWebsocketChannel builds a WebSocketChannel whose embedded Channel uses
// CallBack as its message handler and ChannelWebsocketType as its type. The
// returned channel has Working cleared, i.e. it is not started.
func NewWebsocketChannel(CallBack OnReceiveMsg) *WebSocketChannel {
	base := &Channel{
		CallBack:    CallBack,
		ChannelType: ChannelWebsocketType,
	}
	wsChannel := &WebSocketChannel{Channel: base}
	wsChannel.Working.Clear()
	return wsChannel
}