websocketserver/websocketserver.go (238 lines of code) (raw):

package websocketserver import ( "encoding/json" "fmt" "net" "net/http" "net/url" "time" "github.com/gorilla/websocket" "github.com/sirupsen/logrus" "gitlab.com/gitlab-org/webide-file-sync/common" ) type namespace struct { connectHandler connectFunc disconnectHandler disconnectFunc eventHandlers handlers } type Server struct { conn *websocket.Conn namespaces map[string]*namespace ticker *time.Ticker pongWait time.Duration validOrigins []string } const ( DefaultPongWait = 30 * time.Second ) var upgrader = websocket.Upgrader{ EnableCompression: true, CheckOrigin: func(r *http.Request) bool { return true }, } // NewServer returns a new websocket Server func NewServer(origins []string, pongWait time.Duration) (*Server, error) { if pongWait == 0 { pongWait = DefaultPongWait } pingPeriod := (pongWait * 9) / 10 s := Server{ticker: time.NewTicker(pingPeriod), pongWait: pongWait, namespaces: make(map[string]*namespace)} for _, origin := range origins { if origin == "" { continue } s.validOrigins = append(s.validOrigins, common.NormalizedHostname(origin)) } return &s, nil } func (s *Server) SetOrigins(origins []string) { s.validOrigins = origins } // Serve handles new incoming websocket connection requests func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { u, err := url.Parse(r.Header.Get("Origin")) if err != nil || !s.isValidOrigin(u.Host) { message := u.Host if err != nil { message = err.Error() } logrus.Error("invalid origin: ", message) http.Error(w, "invalid origin", http.StatusBadRequest) return } if s.conn != nil { http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests) return } conn, err := handleConnection(w, r) if err != nil { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } s.setupConnection(conn) s.startPolling() defer s.Close() for { if s.conn == nil { break } msg, err := s.readMessage() if err != nil { break } rsp, err := s.processMessage(msg) if err != nil { break } if rsp != nil { if err := s.WriteMessage(*rsp); err != nil { logrus.Error("error writing message: ", err) break } } } } func (s *Server) isValidOrigin(host string) bool { if len(s.validOrigins) == 0 { return true } nHost := common.NormalizedHostname(host) for _, validOrigin := range s.validOrigins { if validOrigin == "*" || nHost == validOrigin { return true } } return false } func handleConnection(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) { var err error // Setting the same protocol received for the websocket connection respHeader := make(http.Header) respHeader.Add("Sec-WebSocket-Protocol", "webfilesync.gitlab.com") conn, err := upgrader.Upgrade(w, r, respHeader) if err != nil { logrus.Error("error upgrading connection: ", err) return nil, err } return conn, nil } func (s *Server) readMessage() (common.Message, error) { var msg common.Message err := s.conn.ReadJSON(&msg) if err != nil { logrus.Error("error reading message: ", err) s.handleReadError(err) } return msg, err } // If the error is already a websocket close error we don't need to notify the client, // it will be automatically handled. // When the error is because of the JSON decoding or deadline exceeded, // we send a custom error to the client and close the connection func (s *Server) handleReadError(err error) { var closeMessage []byte switch err.(type) { case net.Error: netError, _ := err.(net.Error) if netError.Timeout() { closeMessage = websocket.FormatCloseMessage(websocket.CloseGoingAway, err.Error()) } case *websocket.CloseError: // NOOP default: closeMessage = websocket.FormatCloseMessage(websocket.CloseUnsupportedData, "invalid message") } if closeMessage != nil && s.conn != nil { s.conn.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(3*time.Second)) } } func (s *Server) setupConnection(conn *websocket.Conn) { s.conn = conn s.conn.SetCloseHandler(s.closeHandler) s.setReadDeadline(s.pongWait) s.conn.SetPongHandler(s.pongHandler) } func (s *Server) startPolling() { go func() { for range s.ticker.C { if s.conn == nil { return } if err := s.conn.WriteMessage(websocket.PingMessage, nil); err != nil { logrus.Error("error sending ping message: ", err) return } } }() } func (s *Server) stopPolling() { s.ticker.Stop() } // WriteMessage sends a transport message to the client func (s *Server) WriteMessage(msg common.Message) error { if msg.Code != common.Ping { msg.Code = common.Ack } else { msg.Code = common.Pong msg.Payload = nil } return s.conn.WriteJSON(msg) } // WriteErrorMessage sends a transport error message to the client func (s *Server) WriteErrorMessage(msg common.Message, code int, message string) error { msg.Error = common.ErrorMessage{Code: code, Message: message} msg.Payload = nil return s.WriteMessage(msg) } func (s *Server) processMessage(msg common.Message) (*common.Message, error) { if msg.Code == common.Ping { return &msg, nil } result, err := s.handleMessageType(msg) if err != nil { logrus.Error("error handling the received message: ", err) // Try to send a message with the error err := s.WriteErrorMessage(msg, http.StatusInternalServerError, err.Error()) // if the message cannot be sent, then the connection is closed if err != nil { logrus.Error("error sending error message: ", err) return nil, err } // we had an error handling the message and we notified the user, // we can start reading messages again return nil, nil } var payload json.RawMessage if result != nil { payload, err = json.Marshal(result) if err != nil { logrus.Error("error encoding response: ", err) err := s.WriteErrorMessage(msg, http.StatusInternalServerError, "error encoding response") if err != nil { return nil, err } return nil, nil } msg.Payload = payload } return &msg, nil } func (s *Server) handleMessageType(msg common.Message) (*common.Response, error) { n := s.namespaces[msg.Namespace] if n == nil { return nil, fmt.Errorf("namespace %s is not registered", msg.Namespace) } switch msg.Code { case common.Connect: if n.connectHandler == nil { return nil, nil } return n.connectHandler(s.conn), nil case common.Disconnect: if n.disconnectHandler == nil { return nil, nil } return n.disconnectHandler(s.conn), nil case common.Event: return s.callEventHandler(msg.EventName, n, msg.Payload) default: return nil, fmt.Errorf("invalid message") } } // Close stops the current websocket connection and also stops the connection polling func (s *Server) Close() { if s.conn != nil { s.conn.Close() } s.stopPolling() s.conn = nil } func (s *Server) setReadDeadline(t time.Duration) { s.conn.SetReadDeadline(time.Now().Add(t)) }