Rework websockets

This commit is contained in:
rxdn 2023-09-13 17:06:15 +01:00
parent e266eda682
commit b7afa2373a
11 changed files with 431 additions and 218 deletions

View File

@ -55,7 +55,7 @@ func CloseTicket(ctx *gin.Context) {
} }
hasPermission, requestErr := utils.HasPermissionToViewTicket(guildId, userId, ticket) hasPermission, requestErr := utils.HasPermissionToViewTicket(guildId, userId, ticket)
if err != nil { if requestErr != nil {
ctx.JSON(requestErr.StatusCode, utils.ErrorJson(requestErr)) ctx.JSON(requestErr.StatusCode, utils.ErrorJson(requestErr))
return return
} }

View File

@ -64,7 +64,7 @@ func GetTicket(ctx *gin.Context) {
} }
hasPermission, requestErr := utils.HasPermissionToViewTicket(guildId, userId, ticket) hasPermission, requestErr := utils.HasPermissionToViewTicket(guildId, userId, ticket)
if err != nil { if requestErr != nil {
ctx.JSON(requestErr.StatusCode, utils.ErrorJson(requestErr)) ctx.JSON(requestErr.StatusCode, utils.ErrorJson(requestErr))
return return
} }

View File

@ -0,0 +1,136 @@
package livechat
import (
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"time"
)
type Client struct {
Manager *SocketManager
Ws *websocket.Conn
RequestCtx *gin.Context
Authenticated bool
GuildId uint64
TicketId int
tx chan any
flush chan chan struct{}
}
const (
messageSizeLimit = 1024 * 32
keepaliveFrequency = 45 * time.Second
keepaliveTimeout = 60 * time.Second
writeTimeout = 10 * time.Second
)
func NewClient(manager *SocketManager, ws *websocket.Conn, c *gin.Context, guildId uint64, ticketId int) *Client {
return &Client{
Manager: manager,
Ws: ws,
RequestCtx: c,
Authenticated: false,
GuildId: guildId,
TicketId: ticketId,
tx: make(chan any),
flush: make(chan chan struct{}),
}
}
func (c *Client) Close() {
close(c.tx)
}
func (c *Client) StartReadLoop() error {
defer func() {
c.Manager.unregister <- c
_ = c.Ws.Close()
c.Close()
}()
// Set up connection properties
c.Ws.SetReadLimit(messageSizeLimit)
if err := c.Ws.SetReadDeadline(time.Now().Add(keepaliveTimeout)); err != nil {
return err
}
c.Ws.SetPongHandler(func(appData string) error {
return c.Ws.SetReadDeadline(time.Now().Add(keepaliveTimeout))
})
for {
var event Event
if err := c.Ws.ReadJSON(&event); err != nil {
return err
}
if !c.Authenticated && event.Type != EventTypeAuth {
if err := c.Ws.WriteJSON(NewErrorMessage("Unauthorized")); err != nil {
return err
}
return nil
}
if err := c.HandleEvent(event); err != nil {
c.RequestCtx.Error(err)
c.Write(NewErrorMessage(err.Error()))
c.Flush()
_ = c.Ws.Close()
return err
}
}
}
func (c *Client) Write(msg any) {
c.tx <- msg
}
func (c *Client) StartWriteLoop() error {
ticker := time.NewTicker(keepaliveFrequency)
defer func() {
ticker.Stop()
_ = c.Ws.Close()
}()
for {
select {
case message, ok := <-c.tx:
if err := c.Ws.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
return err
}
if !ok { // Channel was closed
_ = c.Ws.WriteMessage(websocket.CloseMessage, []byte{})
return nil
} else {
if err := c.Ws.WriteJSON(message); err != nil {
return err
}
}
case <-ticker.C:
if err := c.Ws.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
return err
}
if err := c.Ws.WriteMessage(websocket.PingMessage, nil); err != nil {
return err
}
case ch := <-c.flush: // TODO: Channel order is random, there is a race condition here
ch <- struct{}{}
}
}
}
func (c *Client) Flush() {
ch := make(chan struct{})
c.flush <- ch
timer := time.After(time.Second)
select {
case <-ch:
return
case <-timer:
return
}
}

View File

@ -0,0 +1,32 @@
package livechat
import (
"encoding/json"
)
type (
EventType string
Event struct {
Type EventType `json:"type"`
Data json.RawMessage `json:"data,omitempty"`
}
AuthData struct {
Token string `json:"token"`
}
ErrorMessage struct {
Error string `json:"error"`
}
)
const (
EventTypeAuth EventType = "auth"
EventTypeAuthenticated EventType = "authenticated"
EventTypeMessage EventType = "message"
)
func NewErrorMessage(message string) ErrorMessage {
return ErrorMessage{message}
}

View File

@ -0,0 +1,112 @@
package livechat
import (
"encoding/json"
"errors"
"fmt"
"github.com/TicketsBot/GoPanel/botcontext"
"github.com/TicketsBot/GoPanel/config"
dbclient "github.com/TicketsBot/GoPanel/database"
"github.com/TicketsBot/GoPanel/internal/api"
"github.com/TicketsBot/GoPanel/rpc"
"github.com/TicketsBot/GoPanel/utils"
"github.com/TicketsBot/common/premium"
"github.com/golang-jwt/jwt"
"net/http"
"strconv"
)
func (c *Client) HandleEvent(event Event) error {
switch event.Type {
case EventTypeAuth:
var data AuthData
if err := json.Unmarshal(event.Data, &data); err != nil {
c.Write(NewErrorMessage("Malformed event payload"))
_ = c.Ws.Close()
c.Flush()
return err
}
if err := c.handleAuthEvent(data); err != nil {
return err
}
}
return nil
}
func (c *Client) handleAuthEvent(data AuthData) error {
if c.Authenticated {
return api.NewErrorWithMessage(http.StatusBadRequest, errors.New("Already authenticated"), "Already authenticated")
}
token, err := jwt.Parse(data.Token, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(config.Conf.Server.Secret), nil
})
if err != nil {
return api.NewErrorWithMessage(http.StatusUnauthorized, err, "Invalid token")
}
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
return api.NewErrorWithMessage(http.StatusUnauthorized, err, "Invalid token data")
}
userIdStr, ok := claims["userid"].(string)
if !ok {
return api.NewErrorWithMessage(http.StatusUnauthorized, err, "Invalid token data")
}
userId, err := strconv.ParseUint(userIdStr, 10, 64)
if err != nil {
return api.NewErrorWithMessage(http.StatusUnauthorized, err, "Invalid token data")
}
// Get the ticket
ticket, err := dbclient.Client.Tickets.Get(c.TicketId, c.GuildId)
if err != nil {
return api.NewErrorWithMessage(http.StatusInternalServerError, err, "Error retrieving ticket data")
}
if ticket.Id == 0 || ticket.GuildId == 0 || ticket.GuildId != c.GuildId {
return api.NewErrorWithMessage(http.StatusNotFound, err, "Ticket not found")
}
// Verify the user has permissions to be here
hasPermission, requestErr := utils.HasPermissionToViewTicket(c.GuildId, userId, ticket)
if requestErr != nil {
return api.NewErrorWithMessage(http.StatusInternalServerError, err, "Error retrieving permission data")
}
if !hasPermission {
return api.NewErrorWithMessage(http.StatusForbidden, err, "You do not have permission to view this ticket")
}
// Check premium
botContext, err := botcontext.ContextForGuild(c.GuildId)
if err != nil {
return api.NewErrorWithMessage(http.StatusInternalServerError, err, "Error retrieving bot context")
}
// Verify the guild is premium
premiumTier, err := rpc.PremiumClient.GetTierByGuildId(c.GuildId, true, botContext.Token, botContext.RateLimiter)
if err != nil {
return api.NewErrorWithMessage(http.StatusInternalServerError, err, "Error retrieving premium tier")
}
if premiumTier == premium.None {
return api.NewErrorWithMessage(http.StatusPaymentRequired, err, "Live-chat requires premium to use")
}
c.Authenticated = true
c.Write(Event{
Type: EventTypeAuthenticated,
})
return nil
}

View File

@ -0,0 +1,44 @@
package livechat
import (
"github.com/TicketsBot/GoPanel/config"
"github.com/TicketsBot/GoPanel/utils"
"github.com/gin-gonic/gin"
"github.com/gorilla/websocket"
"net/http"
"strconv"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return r.Header.Get("Origin") == config.Conf.Server.BaseUrl
},
}
func GetLiveChatHandler(sm *SocketManager) gin.HandlerFunc {
return func(c *gin.Context) {
conn, err := upgrader.Upgrade(c.Writer, c.Request, nil)
if err != nil {
return
}
guildId, err := strconv.ParseUint(c.Param("id"), 10, 64)
if err != nil {
c.JSON(400, utils.ErrorJson(err))
return
}
ticketId, err := strconv.Atoi(c.Param("ticketId"))
if err != nil {
c.JSON(400, utils.ErrorJson(err))
return
}
client := NewClient(sm, conn, c, guildId, ticketId)
sm.register <- client
go client.StartReadLoop()
go client.StartWriteLoop()
}
}

View File

@ -0,0 +1,84 @@
package livechat
import (
"encoding/json"
"github.com/TicketsBot/common/chatrelay"
)
type (
SocketManager struct {
clients map[uint64][]*Client // Remember: A client might not be authenticated!
messages chan chatrelay.MessageData
register chan *Client
unregister chan *Client
}
)
func NewSocketManager() *SocketManager {
return &SocketManager{
clients: map[uint64][]*Client{},
messages: make(chan chatrelay.MessageData),
register: make(chan *Client),
unregister: make(chan *Client),
}
}
func (sm *SocketManager) Run() {
for {
select {
case client := <-sm.register:
guildClients := sm.clients[client.GuildId]
guildClients = append(guildClients, client)
sm.clients[client.GuildId] = guildClients
case client := <-sm.unregister:
guildClients := sm.clients[client.GuildId]
if len(guildClients) == 0 {
continue // TODO: Warn
}
i := -1
for index, el := range guildClients {
if el == client {
i = index
break
}
}
if i != -1 {
guildClients = guildClients[:i+copy(guildClients[i:], guildClients[i+1:])]
}
sm.clients[client.GuildId] = guildClients
case msg := <-sm.messages:
guildClients, ok := sm.clients[msg.Ticket.GuildId]
if !ok || len(guildClients) == 0 { // No clients connected to this API server for this guild
continue
}
encoded, err := json.Marshal(msg.Message)
if err != nil {
continue // TODO: Warn
}
for _, client := range guildClients {
if !client.Authenticated {
continue
}
// Should already be filtered by guild ID, but here we are filtering by ticket ID for the first time
if client.GuildId != msg.Ticket.GuildId || client.TicketId != msg.Ticket.Id {
continue
}
client.Write(Event{
Type: EventTypeMessage,
Data: encoded,
})
}
}
}
}
func (sm *SocketManager) BroadcastMessage(message chatrelay.MessageData) {
sm.messages <- message
}

View File

@ -1,193 +0,0 @@
package root
import (
"encoding/json"
"fmt"
"github.com/TicketsBot/GoPanel/botcontext"
"github.com/TicketsBot/GoPanel/config"
"github.com/TicketsBot/GoPanel/rpc"
"github.com/TicketsBot/GoPanel/utils"
"github.com/TicketsBot/common/permission"
"github.com/TicketsBot/common/premium"
"github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt"
"github.com/gorilla/websocket"
"net/http"
"strconv"
"sync"
"time"
)
var upgrader = websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return r.Header.Get("Origin") == config.Conf.Server.BaseUrl
},
}
var SocketsLock sync.RWMutex
var Sockets []*Socket
type (
Socket struct {
Ws *websocket.Conn
GuildId uint64
TicketId int
}
WsEvent struct {
Type string
Data json.RawMessage
}
AuthEvent struct {
GuildId uint64 `json:"guild_id,string"`
TicketId int `json:"ticket_id"`
Token string `json:"token"`
}
)
var timeout = time.Second * 60
func WebChatWs(ctx *gin.Context) {
conn, err := upgrader.Upgrade(ctx.Writer, ctx.Request, nil)
if err != nil {
return
}
socket := &Socket{
Ws: conn,
}
SocketsLock.Lock()
Sockets = append(Sockets, socket)
SocketsLock.Unlock()
conn.SetCloseHandler(func(code int, text string) error {
i := -1
SocketsLock.Lock()
defer SocketsLock.Unlock()
for index, element := range Sockets {
if element == socket {
i = index
break
}
}
if i != -1 {
Sockets = Sockets[:i+copy(Sockets[i:], Sockets[i+1:])]
}
return nil
})
lastResponse := time.Now()
conn.SetPongHandler(func(a string) error {
lastResponse = time.Now()
return nil
})
go func() {
// We can let this func call the CloseHandler
for {
err := conn.WriteMessage(websocket.PingMessage, []byte("keepalive"))
if err != nil {
conn.Close()
conn.CloseHandler()(1000, "")
return
}
time.Sleep(timeout / 2)
if time.Since(lastResponse) > timeout {
conn.Close()
conn.CloseHandler()(1000, "")
return
}
}
}()
for {
var event WsEvent
err := conn.ReadJSON(&event)
if err != nil {
break
}
if socket.GuildId == 0 && event.Type != "auth" {
conn.Close()
break
} else if event.Type == "auth" {
var authData AuthEvent
if err := json.Unmarshal(event.Data, &authData); err != nil {
conn.Close()
return
}
token, err := jwt.Parse(authData.Token, func(token *jwt.Token) (interface{}, error) {
if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
return nil, fmt.Errorf("unexpected signing method: %v", token.Header["alg"])
}
return []byte(config.Conf.Server.Secret), nil
})
claims, ok := token.Claims.(jwt.MapClaims)
if !ok {
conn.Close()
return
}
userIdStr, ok := claims["userid"].(string)
if !ok {
conn.Close()
return
}
userId, err := strconv.ParseUint(userIdStr, 10, 64)
if err != nil {
conn.Close()
return
}
// Verify the user has permissions to be here
permLevel, err := utils.GetPermissionLevel(authData.GuildId, userId)
if err != nil {
conn.Close()
return
}
if permLevel < permission.Admin {
conn.Close()
return
}
botContext, err := botcontext.ContextForGuild(authData.GuildId)
if err != nil {
ctx.JSON(500, gin.H{
"success": false,
"error": err.Error(),
})
return
}
// Verify the guild is premium
premiumTier, err := rpc.PremiumClient.GetTierByGuildId(authData.GuildId, true, botContext.Token, botContext.RateLimiter)
if err != nil {
ctx.JSON(500, utils.ErrorJson(err))
return
}
if premiumTier == premium.None {
conn.Close()
return
}
SocketsLock.Lock()
socket.GuildId = authData.GuildId
socket.TicketId = authData.TicketId
SocketsLock.Unlock()
}
}
}

View File

@ -12,6 +12,7 @@ import (
api_tags "github.com/TicketsBot/GoPanel/app/http/endpoints/api/tags" api_tags "github.com/TicketsBot/GoPanel/app/http/endpoints/api/tags"
api_team "github.com/TicketsBot/GoPanel/app/http/endpoints/api/team" api_team "github.com/TicketsBot/GoPanel/app/http/endpoints/api/team"
api_ticket "github.com/TicketsBot/GoPanel/app/http/endpoints/api/ticket" api_ticket "github.com/TicketsBot/GoPanel/app/http/endpoints/api/ticket"
"github.com/TicketsBot/GoPanel/app/http/endpoints/api/ticket/livechat"
api_transcripts "github.com/TicketsBot/GoPanel/app/http/endpoints/api/transcripts" api_transcripts "github.com/TicketsBot/GoPanel/app/http/endpoints/api/transcripts"
api_whitelabel "github.com/TicketsBot/GoPanel/app/http/endpoints/api/whitelabel" api_whitelabel "github.com/TicketsBot/GoPanel/app/http/endpoints/api/whitelabel"
"github.com/TicketsBot/GoPanel/app/http/endpoints/root" "github.com/TicketsBot/GoPanel/app/http/endpoints/root"
@ -25,7 +26,7 @@ import (
"time" "time"
) )
func StartServer() { func StartServer(sm *livechat.SocketManager) {
log.Println("Starting HTTP server") log.Println("Starting HTTP server")
router := gin.Default() router := gin.Default()
@ -71,8 +72,6 @@ func StartServer() {
ctx.String(200, "Disallow: /") ctx.String(200, "Disallow: /")
}) })
router.GET("/webchat", root.WebChatWs)
router.POST("/callback", middleware.VerifyXTicketsHeader, root.CallbackHandler) router.POST("/callback", middleware.VerifyXTicketsHeader, root.CallbackHandler)
router.POST("/logout", middleware.VerifyXTicketsHeader, middleware.AuthenticateToken, root.LogoutHandler) router.POST("/logout", middleware.VerifyXTicketsHeader, middleware.AuthenticateToken, root.LogoutHandler)
@ -155,6 +154,9 @@ func StartServer() {
guildAuthApiSupport.POST("/tickets/:ticketId", rl(middleware.RateLimitTypeGuild, 5, time.Second*5), api_ticket.SendMessage) guildAuthApiSupport.POST("/tickets/:ticketId", rl(middleware.RateLimitTypeGuild, 5, time.Second*5), api_ticket.SendMessage)
guildAuthApiSupport.DELETE("/tickets/:ticketId", api_ticket.CloseTicket) guildAuthApiSupport.DELETE("/tickets/:ticketId", api_ticket.CloseTicket)
// Websockets do not support headers: so we must implement authentication over the WS connection
router.GET("/api/:id/tickets/:ticketId/live-chat", livechat.GetLiveChatHandler(sm))
guildAuthApiSupport.GET("/tags", api_tags.TagsListHandler) guildAuthApiSupport.GET("/tags", api_tags.TagsListHandler)
guildAuthApiSupport.PUT("/tags", api_tags.CreateTag) guildAuthApiSupport.PUT("/tags", api_tags.CreateTag)
guildAuthApiSupport.DELETE("/tags", api_tags.DeleteTag) guildAuthApiSupport.DELETE("/tags", api_tags.DeleteTag)

View File

@ -3,7 +3,7 @@ package main
import ( import (
"fmt" "fmt"
app "github.com/TicketsBot/GoPanel/app/http" app "github.com/TicketsBot/GoPanel/app/http"
"github.com/TicketsBot/GoPanel/app/http/endpoints/root" "github.com/TicketsBot/GoPanel/app/http/endpoints/api/ticket/livechat"
"github.com/TicketsBot/GoPanel/config" "github.com/TicketsBot/GoPanel/config"
"github.com/TicketsBot/GoPanel/database" "github.com/TicketsBot/GoPanel/database"
"github.com/TicketsBot/GoPanel/redis" "github.com/TicketsBot/GoPanel/redis"
@ -59,7 +59,11 @@ func main() {
fmt.Println("Connecting to Redis...") fmt.Println("Connecting to Redis...")
redis.Client = redis.NewRedisClient() redis.Client = redis.NewRedisClient()
go ListenChat(redis.Client)
socketManager := livechat.NewSocketManager()
go socketManager.Run()
go ListenChat(redis.Client, socketManager)
if !config.Conf.Debug { if !config.Conf.Debug {
rpc.PremiumClient = premium.NewPremiumLookupClient( rpc.PremiumClient = premium.NewPremiumLookupClient(
@ -74,23 +78,15 @@ func main() {
} }
fmt.Println("Starting server...") fmt.Println("Starting server...")
app.StartServer() app.StartServer(socketManager)
} }
func ListenChat(client redis.RedisClient) { func ListenChat(client redis.RedisClient, sm *livechat.SocketManager) {
ch := make(chan chatrelay.MessageData) ch := make(chan chatrelay.MessageData)
go chatrelay.Listen(client.Client, ch) go chatrelay.Listen(client.Client, ch)
for event := range ch { for event := range ch {
root.SocketsLock.RLock() sm.BroadcastMessage(event)
for _, socket := range root.Sockets {
if socket.GuildId == event.Ticket.GuildId && socket.TicketId == event.Ticket.Id {
if err := socket.Ws.WriteJSON(event.Message); err != nil {
fmt.Println(err.Error())
}
}
}
root.SocketsLock.RUnlock()
} }
} }

View File

@ -48,7 +48,7 @@
let isPremium = false; let isPremium = false;
let container; let container;
let WS_URL = env.WS_URL || 'ws://172.26.50.75:3000'; let WS_URL = env.WS_URL || 'ws://localhost:3000';
function scrollContainer() { function scrollContainer() {
container.scrollTop = container.scrollHeight; container.scrollTop = container.scrollHeight;
@ -84,23 +84,23 @@
} }
function connectWebsocket() { function connectWebsocket() {
const ws = new WebSocket(`${WS_URL}/webchat`); const ws = new WebSocket(`${WS_URL}/api/${guildId}/tickets/${ticketId}/live-chat`);
ws.onopen = () => { ws.onopen = () => {
ws.send(JSON.stringify({ ws.send(JSON.stringify({
"type": "auth", "type": "auth",
"data": { "data": {
"guild_id": guildId,
"ticket_id": ticketId,
"token": getToken(), "token": getToken(),
} }
})); }));
}; };
ws.onmessage = (evt) => { ws.onmessage = (evt) => {
const data = JSON.parse(evt.data); const payload = JSON.parse(evt.data);
messages = [...messages, data]; if (payload.type === "message") {
messages = [...messages, payload.data];
scrollContainer(); scrollContainer();
}
}; };
} }