From b7afa2373a99253704405273afdc6eb1ecf025e3 Mon Sep 17 00:00:00 2001 From: rxdn <29165304+rxdn@users.noreply.github.com> Date: Wed, 13 Sep 2023 17:06:15 +0100 Subject: [PATCH] Rework websockets --- app/http/endpoints/api/ticket/closeticket.go | 2 +- app/http/endpoints/api/ticket/getticket.go | 2 +- .../endpoints/api/ticket/livechat/client.go | 136 ++++++++++++ .../endpoints/api/ticket/livechat/event.go | 32 +++ .../api/ticket/livechat/eventhandler.go | 112 ++++++++++ .../endpoints/api/ticket/livechat/livechat.go | 44 ++++ .../endpoints/api/ticket/livechat/manager.go | 84 ++++++++ app/http/endpoints/root/webchatws.go | 193 ------------------ app/http/server.go | 8 +- cmd/api/main.go | 22 +- frontend/src/views/TicketView.svelte | 14 +- 11 files changed, 431 insertions(+), 218 deletions(-) create mode 100644 app/http/endpoints/api/ticket/livechat/client.go create mode 100644 app/http/endpoints/api/ticket/livechat/event.go create mode 100644 app/http/endpoints/api/ticket/livechat/eventhandler.go create mode 100644 app/http/endpoints/api/ticket/livechat/livechat.go create mode 100644 app/http/endpoints/api/ticket/livechat/manager.go delete mode 100644 app/http/endpoints/root/webchatws.go diff --git a/app/http/endpoints/api/ticket/closeticket.go b/app/http/endpoints/api/ticket/closeticket.go index 1b86a52..02a3207 100644 --- a/app/http/endpoints/api/ticket/closeticket.go +++ b/app/http/endpoints/api/ticket/closeticket.go @@ -55,7 +55,7 @@ func CloseTicket(ctx *gin.Context) { } hasPermission, requestErr := utils.HasPermissionToViewTicket(guildId, userId, ticket) - if err != nil { + if requestErr != nil { ctx.JSON(requestErr.StatusCode, utils.ErrorJson(requestErr)) return } diff --git a/app/http/endpoints/api/ticket/getticket.go b/app/http/endpoints/api/ticket/getticket.go index 3cc5603..b0b67e7 100644 --- a/app/http/endpoints/api/ticket/getticket.go +++ b/app/http/endpoints/api/ticket/getticket.go @@ -64,7 +64,7 @@ func GetTicket(ctx *gin.Context) { } hasPermission, requestErr := utils.HasPermissionToViewTicket(guildId, userId, ticket) - if err != nil { + if requestErr != nil { ctx.JSON(requestErr.StatusCode, utils.ErrorJson(requestErr)) return } diff --git a/app/http/endpoints/api/ticket/livechat/client.go b/app/http/endpoints/api/ticket/livechat/client.go new file mode 100644 index 0000000..54bb13a --- /dev/null +++ b/app/http/endpoints/api/ticket/livechat/client.go @@ -0,0 +1,136 @@ +package livechat + +import ( + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "time" +) + +type Client struct { + Manager *SocketManager + Ws *websocket.Conn + RequestCtx *gin.Context + Authenticated bool + GuildId uint64 + TicketId int + tx chan any + flush chan chan struct{} +} + +const ( + messageSizeLimit = 1024 * 32 + keepaliveFrequency = 45 * time.Second + keepaliveTimeout = 60 * time.Second + writeTimeout = 10 * time.Second +) + +func NewClient(manager *SocketManager, ws *websocket.Conn, c *gin.Context, guildId uint64, ticketId int) *Client { + return &Client{ + Manager: manager, + Ws: ws, + RequestCtx: c, + Authenticated: false, + GuildId: guildId, + TicketId: ticketId, + tx: make(chan any), + flush: make(chan chan struct{}), + } +} + +func (c *Client) Close() { + close(c.tx) +} + +func (c *Client) StartReadLoop() error { + defer func() { + c.Manager.unregister <- c + _ = c.Ws.Close() + c.Close() + }() + + // Set up connection properties + c.Ws.SetReadLimit(messageSizeLimit) + if err := c.Ws.SetReadDeadline(time.Now().Add(keepaliveTimeout)); err != nil { + return err + } + + c.Ws.SetPongHandler(func(appData string) error { + return c.Ws.SetReadDeadline(time.Now().Add(keepaliveTimeout)) + }) + + for { + var event Event + if err := c.Ws.ReadJSON(&event); err != nil { + return err + } + + if !c.Authenticated && event.Type != EventTypeAuth { + if err := c.Ws.WriteJSON(NewErrorMessage("Unauthorized")); err != nil { + return err + } + + return nil + } + + if err := c.HandleEvent(event); err != nil { + c.RequestCtx.Error(err) + c.Write(NewErrorMessage(err.Error())) + c.Flush() + _ = c.Ws.Close() + return err + } + } +} + +func (c *Client) Write(msg any) { + c.tx <- msg +} + +func (c *Client) StartWriteLoop() error { + ticker := time.NewTicker(keepaliveFrequency) + defer func() { + ticker.Stop() + _ = c.Ws.Close() + }() + + for { + select { + case message, ok := <-c.tx: + if err := c.Ws.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { + return err + } + + if !ok { // Channel was closed + _ = c.Ws.WriteMessage(websocket.CloseMessage, []byte{}) + return nil + } else { + if err := c.Ws.WriteJSON(message); err != nil { + return err + } + } + case <-ticker.C: + if err := c.Ws.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil { + return err + } + + if err := c.Ws.WriteMessage(websocket.PingMessage, nil); err != nil { + return err + } + case ch := <-c.flush: // TODO: Channel order is random, there is a race condition here + ch <- struct{}{} + } + } +} + +func (c *Client) Flush() { + ch := make(chan struct{}) + c.flush <- ch + + timer := time.After(time.Second) + select { + case <-ch: + return + case <-timer: + return + } +} diff --git a/app/http/endpoints/api/ticket/livechat/event.go b/app/http/endpoints/api/ticket/livechat/event.go new file mode 100644 index 0000000..b797d65 --- /dev/null +++ b/app/http/endpoints/api/ticket/livechat/event.go @@ -0,0 +1,32 @@ +package livechat + +import ( + "encoding/json" +) + +type ( + EventType string + + Event struct { + Type EventType `json:"type"` + Data json.RawMessage `json:"data,omitempty"` + } + + AuthData struct { + Token string `json:"token"` + } + + ErrorMessage struct { + Error string `json:"error"` + } +) + +const ( + EventTypeAuth EventType = "auth" + EventTypeAuthenticated EventType = "authenticated" + EventTypeMessage EventType = "message" +) + +func NewErrorMessage(message string) ErrorMessage { + return ErrorMessage{message} +} diff --git a/app/http/endpoints/api/ticket/livechat/eventhandler.go b/app/http/endpoints/api/ticket/livechat/eventhandler.go new file mode 100644 index 0000000..cd49d02 --- /dev/null +++ b/app/http/endpoints/api/ticket/livechat/eventhandler.go @@ -0,0 +1,112 @@ +package livechat + +import ( + "encoding/json" + "errors" + "fmt" + "github.com/TicketsBot/GoPanel/botcontext" + "github.com/TicketsBot/GoPanel/config" + dbclient "github.com/TicketsBot/GoPanel/database" + "github.com/TicketsBot/GoPanel/internal/api" + "github.com/TicketsBot/GoPanel/rpc" + "github.com/TicketsBot/GoPanel/utils" + "github.com/TicketsBot/common/premium" + "github.com/golang-jwt/jwt" + "net/http" + "strconv" +) + +func (c *Client) HandleEvent(event Event) error { + switch event.Type { + case EventTypeAuth: + var data AuthData + if err := json.Unmarshal(event.Data, &data); err != nil { + c.Write(NewErrorMessage("Malformed event payload")) + _ = c.Ws.Close() + c.Flush() + return err + } + + if err := c.handleAuthEvent(data); err != nil { + return err + } + } + + return nil +} + +func (c *Client) handleAuthEvent(data AuthData) error { + if c.Authenticated { + return api.NewErrorWithMessage(http.StatusBadRequest, errors.New("Already authenticated"), "Already authenticated") + } + + token, err := jwt.Parse(data.Token, func(token *jwt.Token) (interface{}, error) { + if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) + } + + return []byte(config.Conf.Server.Secret), nil + }) + if err != nil { + return api.NewErrorWithMessage(http.StatusUnauthorized, err, "Invalid token") + } + + claims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return api.NewErrorWithMessage(http.StatusUnauthorized, err, "Invalid token data") + } + + userIdStr, ok := claims["userid"].(string) + if !ok { + return api.NewErrorWithMessage(http.StatusUnauthorized, err, "Invalid token data") + } + + userId, err := strconv.ParseUint(userIdStr, 10, 64) + if err != nil { + return api.NewErrorWithMessage(http.StatusUnauthorized, err, "Invalid token data") + } + + // Get the ticket + ticket, err := dbclient.Client.Tickets.Get(c.TicketId, c.GuildId) + if err != nil { + return api.NewErrorWithMessage(http.StatusInternalServerError, err, "Error retrieving ticket data") + } + + if ticket.Id == 0 || ticket.GuildId == 0 || ticket.GuildId != c.GuildId { + return api.NewErrorWithMessage(http.StatusNotFound, err, "Ticket not found") + } + + // Verify the user has permissions to be here + hasPermission, requestErr := utils.HasPermissionToViewTicket(c.GuildId, userId, ticket) + if requestErr != nil { + return api.NewErrorWithMessage(http.StatusInternalServerError, err, "Error retrieving permission data") + } + + if !hasPermission { + return api.NewErrorWithMessage(http.StatusForbidden, err, "You do not have permission to view this ticket") + } + + // Check premium + botContext, err := botcontext.ContextForGuild(c.GuildId) + if err != nil { + return api.NewErrorWithMessage(http.StatusInternalServerError, err, "Error retrieving bot context") + } + + // Verify the guild is premium + premiumTier, err := rpc.PremiumClient.GetTierByGuildId(c.GuildId, true, botContext.Token, botContext.RateLimiter) + if err != nil { + return api.NewErrorWithMessage(http.StatusInternalServerError, err, "Error retrieving premium tier") + } + + if premiumTier == premium.None { + return api.NewErrorWithMessage(http.StatusPaymentRequired, err, "Live-chat requires premium to use") + } + + c.Authenticated = true + + c.Write(Event{ + Type: EventTypeAuthenticated, + }) + + return nil +} diff --git a/app/http/endpoints/api/ticket/livechat/livechat.go b/app/http/endpoints/api/ticket/livechat/livechat.go new file mode 100644 index 0000000..43d44d3 --- /dev/null +++ b/app/http/endpoints/api/ticket/livechat/livechat.go @@ -0,0 +1,44 @@ +package livechat + +import ( + "github.com/TicketsBot/GoPanel/config" + "github.com/TicketsBot/GoPanel/utils" + "github.com/gin-gonic/gin" + "github.com/gorilla/websocket" + "net/http" + "strconv" +) + +var upgrader = websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + return r.Header.Get("Origin") == config.Conf.Server.BaseUrl + }, +} + +func GetLiveChatHandler(sm *SocketManager) gin.HandlerFunc { + return func(c *gin.Context) { + conn, err := upgrader.Upgrade(c.Writer, c.Request, nil) + if err != nil { + return + } + + guildId, err := strconv.ParseUint(c.Param("id"), 10, 64) + if err != nil { + c.JSON(400, utils.ErrorJson(err)) + return + } + + ticketId, err := strconv.Atoi(c.Param("ticketId")) + if err != nil { + c.JSON(400, utils.ErrorJson(err)) + return + } + + client := NewClient(sm, conn, c, guildId, ticketId) + sm.register <- client + go client.StartReadLoop() + go client.StartWriteLoop() + } +} diff --git a/app/http/endpoints/api/ticket/livechat/manager.go b/app/http/endpoints/api/ticket/livechat/manager.go new file mode 100644 index 0000000..9b0cd2f --- /dev/null +++ b/app/http/endpoints/api/ticket/livechat/manager.go @@ -0,0 +1,84 @@ +package livechat + +import ( + "encoding/json" + "github.com/TicketsBot/common/chatrelay" +) + +type ( + SocketManager struct { + clients map[uint64][]*Client // Remember: A client might not be authenticated! + messages chan chatrelay.MessageData + register chan *Client + unregister chan *Client + } +) + +func NewSocketManager() *SocketManager { + return &SocketManager{ + clients: map[uint64][]*Client{}, + messages: make(chan chatrelay.MessageData), + register: make(chan *Client), + unregister: make(chan *Client), + } +} + +func (sm *SocketManager) Run() { + for { + select { + case client := <-sm.register: + guildClients := sm.clients[client.GuildId] + guildClients = append(guildClients, client) + sm.clients[client.GuildId] = guildClients + case client := <-sm.unregister: + guildClients := sm.clients[client.GuildId] + if len(guildClients) == 0 { + continue // TODO: Warn + } + + i := -1 + for index, el := range guildClients { + if el == client { + i = index + break + } + } + + if i != -1 { + guildClients = guildClients[:i+copy(guildClients[i:], guildClients[i+1:])] + } + + sm.clients[client.GuildId] = guildClients + case msg := <-sm.messages: + guildClients, ok := sm.clients[msg.Ticket.GuildId] + if !ok || len(guildClients) == 0 { // No clients connected to this API server for this guild + continue + } + + encoded, err := json.Marshal(msg.Message) + if err != nil { + continue // TODO: Warn + } + + for _, client := range guildClients { + if !client.Authenticated { + continue + } + + // Should already be filtered by guild ID, but here we are filtering by ticket ID for the first time + if client.GuildId != msg.Ticket.GuildId || client.TicketId != msg.Ticket.Id { + continue + } + + client.Write(Event{ + Type: EventTypeMessage, + Data: encoded, + }) + } + } + } +} + +func (sm *SocketManager) BroadcastMessage(message chatrelay.MessageData) { + sm.messages <- message +} diff --git a/app/http/endpoints/root/webchatws.go b/app/http/endpoints/root/webchatws.go deleted file mode 100644 index 3cd6780..0000000 --- a/app/http/endpoints/root/webchatws.go +++ /dev/null @@ -1,193 +0,0 @@ -package root - -import ( - "encoding/json" - "fmt" - "github.com/TicketsBot/GoPanel/botcontext" - "github.com/TicketsBot/GoPanel/config" - "github.com/TicketsBot/GoPanel/rpc" - "github.com/TicketsBot/GoPanel/utils" - "github.com/TicketsBot/common/permission" - "github.com/TicketsBot/common/premium" - "github.com/gin-gonic/gin" - "github.com/golang-jwt/jwt" - "github.com/gorilla/websocket" - "net/http" - "strconv" - "sync" - "time" -) - -var upgrader = websocket.Upgrader{ - ReadBufferSize: 1024, - WriteBufferSize: 1024, - CheckOrigin: func(r *http.Request) bool { - return r.Header.Get("Origin") == config.Conf.Server.BaseUrl - }, -} - -var SocketsLock sync.RWMutex -var Sockets []*Socket - -type ( - Socket struct { - Ws *websocket.Conn - GuildId uint64 - TicketId int - } - - WsEvent struct { - Type string - Data json.RawMessage - } - - AuthEvent struct { - GuildId uint64 `json:"guild_id,string"` - TicketId int `json:"ticket_id"` - Token string `json:"token"` - } -) - -var timeout = time.Second * 60 - -func WebChatWs(ctx *gin.Context) { - conn, err := upgrader.Upgrade(ctx.Writer, ctx.Request, nil) - if err != nil { - return - } - - socket := &Socket{ - Ws: conn, - } - - SocketsLock.Lock() - Sockets = append(Sockets, socket) - SocketsLock.Unlock() - - conn.SetCloseHandler(func(code int, text string) error { - i := -1 - SocketsLock.Lock() - defer SocketsLock.Unlock() - - for index, element := range Sockets { - if element == socket { - i = index - break - } - } - - if i != -1 { - Sockets = Sockets[:i+copy(Sockets[i:], Sockets[i+1:])] - } - - return nil - }) - - lastResponse := time.Now() - conn.SetPongHandler(func(a string) error { - lastResponse = time.Now() - return nil - }) - - go func() { - // We can let this func call the CloseHandler - for { - err := conn.WriteMessage(websocket.PingMessage, []byte("keepalive")) - if err != nil { - conn.Close() - conn.CloseHandler()(1000, "") - return - } - - time.Sleep(timeout / 2) - if time.Since(lastResponse) > timeout { - conn.Close() - conn.CloseHandler()(1000, "") - return - } - } - }() - - for { - var event WsEvent - err := conn.ReadJSON(&event) - if err != nil { - break - } - - if socket.GuildId == 0 && event.Type != "auth" { - conn.Close() - break - } else if event.Type == "auth" { - var authData AuthEvent - if err := json.Unmarshal(event.Data, &authData); err != nil { - conn.Close() - return - } - - token, err := jwt.Parse(authData.Token, func(token *jwt.Token) (interface{}, error) { - if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok { - return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"]) - } - - return []byte(config.Conf.Server.Secret), nil - }) - - claims, ok := token.Claims.(jwt.MapClaims) - if !ok { - conn.Close() - return - } - - userIdStr, ok := claims["userid"].(string) - if !ok { - conn.Close() - return - } - - userId, err := strconv.ParseUint(userIdStr, 10, 64) - if err != nil { - conn.Close() - return - } - - // Verify the user has permissions to be here - permLevel, err := utils.GetPermissionLevel(authData.GuildId, userId) - if err != nil { - conn.Close() - return - } - - if permLevel < permission.Admin { - conn.Close() - return - } - - botContext, err := botcontext.ContextForGuild(authData.GuildId) - if err != nil { - ctx.JSON(500, gin.H{ - "success": false, - "error": err.Error(), - }) - return - } - - // Verify the guild is premium - premiumTier, err := rpc.PremiumClient.GetTierByGuildId(authData.GuildId, true, botContext.Token, botContext.RateLimiter) - if err != nil { - ctx.JSON(500, utils.ErrorJson(err)) - return - } - - if premiumTier == premium.None { - conn.Close() - return - } - - SocketsLock.Lock() - socket.GuildId = authData.GuildId - socket.TicketId = authData.TicketId - SocketsLock.Unlock() - } - } -} diff --git a/app/http/server.go b/app/http/server.go index 4661062..9b8a8a8 100644 --- a/app/http/server.go +++ b/app/http/server.go @@ -12,6 +12,7 @@ import ( api_tags "github.com/TicketsBot/GoPanel/app/http/endpoints/api/tags" api_team "github.com/TicketsBot/GoPanel/app/http/endpoints/api/team" api_ticket "github.com/TicketsBot/GoPanel/app/http/endpoints/api/ticket" + "github.com/TicketsBot/GoPanel/app/http/endpoints/api/ticket/livechat" api_transcripts "github.com/TicketsBot/GoPanel/app/http/endpoints/api/transcripts" api_whitelabel "github.com/TicketsBot/GoPanel/app/http/endpoints/api/whitelabel" "github.com/TicketsBot/GoPanel/app/http/endpoints/root" @@ -25,7 +26,7 @@ import ( "time" ) -func StartServer() { +func StartServer(sm *livechat.SocketManager) { log.Println("Starting HTTP server") router := gin.Default() @@ -71,8 +72,6 @@ func StartServer() { ctx.String(200, "Disallow: /") }) - router.GET("/webchat", root.WebChatWs) - router.POST("/callback", middleware.VerifyXTicketsHeader, root.CallbackHandler) router.POST("/logout", middleware.VerifyXTicketsHeader, middleware.AuthenticateToken, root.LogoutHandler) @@ -155,6 +154,9 @@ func StartServer() { guildAuthApiSupport.POST("/tickets/:ticketId", rl(middleware.RateLimitTypeGuild, 5, time.Second*5), api_ticket.SendMessage) guildAuthApiSupport.DELETE("/tickets/:ticketId", api_ticket.CloseTicket) + // Websockets do not support headers: so we must implement authentication over the WS connection + router.GET("/api/:id/tickets/:ticketId/live-chat", livechat.GetLiveChatHandler(sm)) + guildAuthApiSupport.GET("/tags", api_tags.TagsListHandler) guildAuthApiSupport.PUT("/tags", api_tags.CreateTag) guildAuthApiSupport.DELETE("/tags", api_tags.DeleteTag) diff --git a/cmd/api/main.go b/cmd/api/main.go index 77bd161..25e07e3 100644 --- a/cmd/api/main.go +++ b/cmd/api/main.go @@ -3,7 +3,7 @@ package main import ( "fmt" app "github.com/TicketsBot/GoPanel/app/http" - "github.com/TicketsBot/GoPanel/app/http/endpoints/root" + "github.com/TicketsBot/GoPanel/app/http/endpoints/api/ticket/livechat" "github.com/TicketsBot/GoPanel/config" "github.com/TicketsBot/GoPanel/database" "github.com/TicketsBot/GoPanel/redis" @@ -59,7 +59,11 @@ func main() { fmt.Println("Connecting to Redis...") redis.Client = redis.NewRedisClient() - go ListenChat(redis.Client) + + socketManager := livechat.NewSocketManager() + go socketManager.Run() + + go ListenChat(redis.Client, socketManager) if !config.Conf.Debug { rpc.PremiumClient = premium.NewPremiumLookupClient( @@ -74,23 +78,15 @@ func main() { } fmt.Println("Starting server...") - app.StartServer() + app.StartServer(socketManager) } -func ListenChat(client redis.RedisClient) { +func ListenChat(client redis.RedisClient, sm *livechat.SocketManager) { ch := make(chan chatrelay.MessageData) go chatrelay.Listen(client.Client, ch) for event := range ch { - root.SocketsLock.RLock() - for _, socket := range root.Sockets { - if socket.GuildId == event.Ticket.GuildId && socket.TicketId == event.Ticket.Id { - if err := socket.Ws.WriteJSON(event.Message); err != nil { - fmt.Println(err.Error()) - } - } - } - root.SocketsLock.RUnlock() + sm.BroadcastMessage(event) } } diff --git a/frontend/src/views/TicketView.svelte b/frontend/src/views/TicketView.svelte index 2f19c1b..561f562 100644 --- a/frontend/src/views/TicketView.svelte +++ b/frontend/src/views/TicketView.svelte @@ -48,7 +48,7 @@ let isPremium = false; let container; - let WS_URL = env.WS_URL || 'ws://172.26.50.75:3000'; + let WS_URL = env.WS_URL || 'ws://localhost:3000'; function scrollContainer() { container.scrollTop = container.scrollHeight; @@ -84,23 +84,23 @@ } function connectWebsocket() { - const ws = new WebSocket(`${WS_URL}/webchat`); + const ws = new WebSocket(`${WS_URL}/api/${guildId}/tickets/${ticketId}/live-chat`); ws.onopen = () => { ws.send(JSON.stringify({ "type": "auth", "data": { - "guild_id": guildId, - "ticket_id": ticketId, "token": getToken(), } })); }; ws.onmessage = (evt) => { - const data = JSON.parse(evt.data); - messages = [...messages, data]; - scrollContainer(); + const payload = JSON.parse(evt.data); + if (payload.type === "message") { + messages = [...messages, payload.data]; + scrollContainer(); + } }; }