websocketserver/websocketserver.go (238 lines of code) (raw):
package websocketserver
import (
"encoding/json"
"fmt"
"net"
"net/http"
"net/url"
"time"
"github.com/gorilla/websocket"
"github.com/sirupsen/logrus"
"gitlab.com/gitlab-org/webide-file-sync/common"
)
// namespace holds the callbacks registered under one namespace name.
// Either lifecycle handler may be nil, in which case handleMessageType
// answers the corresponding message with a nil response.
type namespace struct {
	connectHandler    connectFunc    // run for common.Connect messages
	disconnectHandler disconnectFunc // run for common.Disconnect messages
	eventHandlers     handlers       // NOTE(review): presumably keyed by event name; consumed via callEventHandler
}
// Server serves one websocket client at a time and dispatches the messages
// it sends to the registered namespaces.
type Server struct {
	conn         *websocket.Conn       // active client connection; nil when none is connected
	namespaces   map[string]*namespace // handlers, keyed by common.Message.Namespace
	ticker       *time.Ticker          // created by NewServer, stopped by stopPolling
	pongWait     time.Duration         // read deadline applied to the connection
	validOrigins []string              // allowed origin hosts; an empty list allows every origin
}
const (
DefaultPongWait = 30 * time.Second
)
// upgrader performs the websocket handshake with per-message compression
// negotiation enabled. Its CheckOrigin accepts every request on purpose:
// ServeHTTP validates the Origin header against the server's configured
// origins before it ever calls Upgrade.
var upgrader = websocket.Upgrader{
	CheckOrigin:       func(*http.Request) bool { return true },
	EnableCompression: true,
}
// NewServer returns a new websocket Server.
//
// origins lists the hosts allowed to connect: empty entries are ignored and
// the remaining ones are normalized with common.NormalizedHostname. If no
// entries remain, isValidOrigin accepts every origin.
//
// pongWait becomes the read deadline applied to each connection. Any value
// <= 0 selects DefaultPongWait; previously only zero did, and a negative
// value made time.NewTicker panic. The ticker is created with a period of
// 9/10 of pongWait. The returned error is currently always nil.
func NewServer(origins []string, pongWait time.Duration) (*Server, error) {
	if pongWait <= 0 {
		pongWait = DefaultPongWait
	}
	pingPeriod := (pongWait * 9) / 10

	s := &Server{
		ticker:     time.NewTicker(pingPeriod),
		pongWait:   pongWait,
		namespaces: make(map[string]*namespace),
	}
	for _, origin := range origins {
		if origin == "" {
			continue
		}
		s.validOrigins = append(s.validOrigins, common.NormalizedHostname(origin))
	}
	return s, nil
}
// SetOrigins replaces the list of origins allowed to connect.
//
// Entries are processed exactly as NewServer processes its origins argument:
// empty entries are skipped and the rest are normalized with
// common.NormalizedHostname. Previously the raw list was stored, so the same
// configuration matched differently depending on whether it came through
// NewServer or SetOrigins. If no entries remain, every origin is accepted.
// The caller's slice is not retained.
func (s *Server) SetOrigins(origins []string) {
	normalized := make([]string, 0, len(origins))
	for _, origin := range origins {
		if origin == "" {
			continue
		}
		normalized = append(normalized, common.NormalizedHostname(origin))
	}
	s.validOrigins = normalized
}
// ServeHTTP handles a new incoming websocket connection request.
//
// The request's Origin header must parse as a URL whose host passes
// isValidOrigin; otherwise the request is rejected with 400 Bad Request.
// Only one client may be connected at a time, and further requests get
// 429 Too Many Requests.
//
// After a successful upgrade, messages are read, dispatched through
// processMessage and answered. This continues until reading, processing or
// writing fails, or the active connection is cleared. The connection is
// closed when ServeHTTP returns.
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	origin, err := url.Parse(r.Header.Get("Origin"))
	if err != nil {
		// url.Parse returns a nil *URL together with the error, so the host
		// must not be read on this path (doing so used to panic).
		logrus.Error("invalid origin: ", err.Error())
		http.Error(w, "invalid origin", http.StatusBadRequest)
		return
	}
	if !s.isValidOrigin(origin.Host) {
		logrus.Error("invalid origin: ", origin.Host)
		http.Error(w, "invalid origin", http.StatusBadRequest)
		return
	}

	// NOTE(review): s.conn is checked here and set later without any
	// synchronization, so two simultaneous requests can both pass this check.
	if s.conn != nil {
		http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
		return
	}

	conn, err := handleConnection(w, r)
	if err != nil {
		// When the handshake fails, Upgrade has already replied to the client
		// with an HTTP error. Writing a second response here would only
		// trigger a superfluous WriteHeader call.
		return
	}
	s.setupConnection(conn)
	s.startPolling()
	defer s.Close()

	for s.conn != nil {
		msg, err := s.readMessage()
		if err != nil {
			return
		}
		rsp, err := s.processMessage(msg)
		if err != nil {
			return
		}
		if rsp == nil {
			// processMessage already answered (e.g. with an error message).
			continue
		}
		if err := s.WriteMessage(*rsp); err != nil {
			logrus.Error("error writing message: ", err)
			return
		}
	}
}
// isValidOrigin reports whether a connection from host is allowed.
// An empty allow list accepts every host. Otherwise host is normalized with
// common.NormalizedHostname and must equal one of the entries, unless the
// list contains the "*" wildcard.
func (s *Server) isValidOrigin(host string) bool {
	allowed := s.validOrigins
	if len(allowed) == 0 {
		return true
	}
	want := common.NormalizedHostname(host)
	for i := range allowed {
		switch allowed[i] {
		case "*", want:
			return true
		}
	}
	return false
}
// handleConnection upgrades the HTTP request to a websocket connection.
// The handshake response always advertises the "webfilesync.gitlab.com"
// subprotocol. Upgrade failures are logged and returned to the caller.
//
// NOTE(review): the subprotocol is sent whether or not the client offered
// it. RFC 6455 clients must fail the connection when the server selects a
// subprotocol they did not request, so confirm every client offers it.
func handleConnection(w http.ResponseWriter, r *http.Request) (*websocket.Conn, error) {
	header := http.Header{}
	header.Add("Sec-WebSocket-Protocol", "webfilesync.gitlab.com")

	ws, upgradeErr := upgrader.Upgrade(w, r, header)
	if upgradeErr != nil {
		logrus.Error("error upgrading connection: ", upgradeErr)
		return nil, upgradeErr
	}
	return ws, nil
}
// readMessage decodes the next JSON message from the active connection.
// On failure the error is logged and passed to handleReadError, which may
// send a close frame to the client. The error is then returned together with
// whatever was decoded into the message before the failure.
func (s *Server) readMessage() (common.Message, error) {
	var incoming common.Message
	if err := s.conn.ReadJSON(&incoming); err != nil {
		logrus.Error("error reading message: ", err)
		s.handleReadError(err)
		return incoming, err
	}
	return incoming, nil
}
// handleReadError tells the client why its connection is being dropped
// after a failed read:
//
//   - *websocket.CloseError: the peer already sent a close frame, so nothing
//     is sent here (NOTE(review): presumably answered by s.closeHandler,
//     installed in setupConnection).
//   - a net.Error that is a timeout, e.g. the read deadline expired: a
//     CloseGoingAway frame carrying the error text is sent.
//   - any other error, e.g. a JSON decoding failure: a CloseUnsupportedData
//     frame with the reason "invalid message" is sent.
//
// Non-timeout network errors send nothing. Failing to send the close frame
// is logged instead of being silently ignored.
func (s *Server) handleReadError(err error) {
	var closeMessage []byte
	switch e := err.(type) {
	case net.Error:
		if e.Timeout() {
			closeMessage = websocket.FormatCloseMessage(websocket.CloseGoingAway, truncateCloseReason(e.Error()))
		}
	case *websocket.CloseError:
		// NOOP: the close handshake is already under way.
	default:
		closeMessage = websocket.FormatCloseMessage(websocket.CloseUnsupportedData, "invalid message")
	}
	if closeMessage == nil || s.conn == nil {
		return
	}
	if writeErr := s.conn.WriteControl(websocket.CloseMessage, closeMessage, time.Now().Add(3*time.Second)); writeErr != nil {
		logrus.Error("error sending close message: ", writeErr)
	}
}

// truncateCloseReason shortens reason so that a close frame carrying it
// stays valid. RFC 6455 limits control-frame payloads to 125 bytes, two of
// which hold the status code, and WriteControl rejects oversized frames.
// The cut is moved back so that it never splits a UTF-8 sequence.
func truncateCloseReason(reason string) string {
	const maxCloseReasonLen = 123
	if len(reason) <= maxCloseReasonLen {
		return reason
	}
	cut := maxCloseReasonLen
	// Bytes of the form 10xxxxxx are UTF-8 continuation bytes; never cut
	// right before one.
	for cut > 0 && reason[cut]&0xC0 == 0x80 {
		cut--
	}
	return reason[:cut]
}
// setupConnection makes conn the active connection and installs the
// server's close handler. It then sets the initial read deadline to
// s.pongWait from now and installs the server's pong handler.
func (s *Server) setupConnection(conn *websocket.Conn) {
	s.conn = conn
	conn.SetCloseHandler(s.closeHandler)
	s.setReadDeadline(s.pongWait)
	conn.SetPongHandler(s.pongHandler)
}
// startPolling starts a goroutine that pings the active connection every
// 9/10 of s.pongWait. s.pongWait is also the read deadline set in
// setupConnection, so pings go out before that deadline.
//
// Each connection gets its own ticker, which the goroutine stops when it
// exits. The shared s.ticker is not reused because Close stops it for good:
// a stopped ticker never fires again and never closes its channel. Reusing
// it meant no pings after the first connection and a goroutine blocked
// forever per connection.
//
// The goroutine exits when the connection it was started for is no longer
// the active one (Close cleared it or a newer client replaced it), or when a
// ping fails.
//
// Pings are sent with WriteControl, which gorilla/websocket allows to run
// concurrently with the read loop's WriteJSON calls. WriteMessage does not
// have that guarantee.
func (s *Server) startPolling() {
	conn := s.conn
	if conn == nil {
		return
	}
	period := (s.pongWait * 9) / 10
	if period <= 0 {
		// Guards Servers not built by NewServer; NewTicker panics on <= 0.
		period = (DefaultPongWait * 9) / 10
	}
	ticker := time.NewTicker(period)

	go func() {
		defer ticker.Stop()
		const pingWriteWait = 10 * time.Second
		for range ticker.C {
			// NOTE(review): s.conn is read without synchronization, as elsewhere in Server.
			if s.conn != conn {
				return
			}
			if err := conn.WriteControl(websocket.PingMessage, nil, time.Now().Add(pingWriteWait)); err != nil {
				logrus.Error("error sending ping message: ", err)
				return
			}
		}
	}()
}
// stopPolling stops the ticker created by NewServer. It is a no-op when
// there is no ticker, e.g. on a Server not built with NewServer, so that
// Close cannot panic on a nil ticker.
func (s *Server) stopPolling() {
	if s.ticker != nil {
		s.ticker.Stop()
	}
}
// WriteMessage sends a transport message to the client as JSON.
// A Ping message is answered as a Pong with its payload removed; any other
// message is sent with its code set to Ack.
//
// NOTE(review): gorilla/websocket documents at most one concurrent writer
// per connection; confirm callers never invoke this alongside the read loop.
func (s *Server) WriteMessage(msg common.Message) error {
	switch msg.Code {
	case common.Ping:
		msg.Code = common.Pong
		msg.Payload = nil
	default:
		msg.Code = common.Ack
	}
	return s.conn.WriteJSON(msg)
}
// WriteErrorMessage sends a transport error message to the client. The
// message keeps its routing fields, loses its payload, and gets an Error of
// the given code and text. It is delivered through WriteMessage, so its Code
// is rewritten the same way.
func (s *Server) WriteErrorMessage(msg common.Message, code int, message string) error {
	msg.Payload = nil
	msg.Error = common.ErrorMessage{Code: code, Message: message}
	return s.WriteMessage(msg)
}
// processMessage dispatches msg and returns the response to send back.
//
// Ping messages are returned as-is; WriteMessage turns them into Pongs.
// For other messages the namespace handler's result, if any, is
// JSON-encoded into the returned message's Payload.
//
// If the handler fails or its result cannot be encoded, an error message is
// written to the client immediately and (nil, nil) is returned, so the read
// loop keeps going. A non-nil error is returned only when that error message
// itself could not be written.
func (s *Server) processMessage(msg common.Message) (*common.Message, error) {
	if msg.Code == common.Ping {
		return &msg, nil
	}

	result, handleErr := s.handleMessageType(msg)
	if handleErr != nil {
		logrus.Error("error handling the received message: ", handleErr)
		writeErr := s.WriteErrorMessage(msg, http.StatusInternalServerError, handleErr.Error())
		if writeErr != nil {
			// The client cannot even be told about the failure: give up.
			logrus.Error("error sending error message: ", writeErr)
			return nil, writeErr
		}
		return nil, nil
	}

	if result == nil {
		return &msg, nil
	}

	raw, encodeErr := json.Marshal(result)
	if encodeErr != nil {
		logrus.Error("error encoding response: ", encodeErr)
		if writeErr := s.WriteErrorMessage(msg, http.StatusInternalServerError, "error encoding response"); writeErr != nil {
			return nil, writeErr
		}
		return nil, nil
	}
	msg.Payload = json.RawMessage(raw)
	return &msg, nil
}
// handleMessageType routes msg to the namespace registered under
// msg.Namespace and returns that handler's response.
//
// Connect and Disconnect messages run the namespace's connect or disconnect
// handler with the active connection; a missing handler yields a nil
// response. Event messages go to callEventHandler. Unregistered namespaces
// and any other message code produce an error.
func (s *Server) handleMessageType(msg common.Message) (*common.Response, error) {
	ns, found := s.namespaces[msg.Namespace]
	if !found || ns == nil {
		return nil, fmt.Errorf("namespace %s is not registered", msg.Namespace)
	}

	switch msg.Code {
	case common.Connect:
		if handler := ns.connectHandler; handler != nil {
			return handler(s.conn), nil
		}
		return nil, nil
	case common.Disconnect:
		if handler := ns.disconnectHandler; handler != nil {
			return handler(s.conn), nil
		}
		return nil, nil
	case common.Event:
		return s.callEventHandler(msg.EventName, ns, msg.Payload)
	}
	return nil, fmt.Errorf("invalid message")
}
// Close closes the active websocket connection, if there is one, and stops
// the polling ticker. It then clears the active connection so that
// ServeHTTP accepts a new client. Any error from closing the connection is
// ignored.
func (s *Server) Close() {
	if active := s.conn; active != nil {
		active.Close()
	}
	s.stopPolling()
	s.conn = nil
}
// setReadDeadline moves the active connection's read deadline to t from
// now. The error returned by SetReadDeadline is ignored.
func (s *Server) setReadDeadline(t time.Duration) {
	deadline := time.Now().Add(t)
	s.conn.SetReadDeadline(deadline)
}