diff --git a/app/http/endpoints/api/getpermissionlevel.go b/app/http/endpoints/api/getpermissionlevel.go index 0ad1a18..65ad26f 100644 --- a/app/http/endpoints/api/getpermissionlevel.go +++ b/app/http/endpoints/api/getpermissionlevel.go @@ -1,57 +1,28 @@ package api import ( - "context" - "fmt" "github.com/TicketsBot/GoPanel/utils" - "github.com/TicketsBot/common/permission" "github.com/gin-gonic/gin" - "golang.org/x/sync/errgroup" "strconv" - "strings" ) func GetPermissionLevel(ctx *gin.Context) { userId := ctx.Keys["userid"].(uint64) - guilds := strings.Split(ctx.Query("guilds"), ",") - if len(guilds) > 100 { - ctx.JSON(400, gin.H{ - "success": false, - "error": "too many guilds", - }) + guildId, err := strconv.ParseUint(ctx.Query("guild"), 10, 64) + if err != nil { + ctx.JSON(400, utils.ErrorStr("Invalid guild ID")) return } - // TODO: This is insanely inefficient - - levels := make(map[string]permission.PermissionLevel) - - group, _ := errgroup.WithContext(context.Background()) - for _, raw := range guilds { - guildId, err := strconv.ParseUint(raw, 10, 64) - if err != nil { - ctx.JSON(400, gin.H{ - "success": false, - "error": fmt.Sprintf("invalid guild id: %s", raw), - }) - return - } - - group.Go(func() error { - level, err := utils.GetPermissionLevel(guildId, userId) - levels[strconv.FormatUint(guildId, 10)] = level - return err - }) - } - - if err := group.Wait(); err != nil { + permissionLevel, err := utils.GetPermissionLevel(guildId, userId) + if err != nil { ctx.JSON(500, utils.ErrorJson(err)) return } ctx.JSON(200, gin.H{ - "success": true, - "levels": levels, + "success": true, + "permission_level": permissionLevel, }) } diff --git a/app/http/endpoints/api/guilds.go b/app/http/endpoints/api/guilds.go index c089568..9403789 100644 --- a/app/http/endpoints/api/guilds.go +++ b/app/http/endpoints/api/guilds.go @@ -2,14 +2,14 @@ package api import ( "context" - "github.com/TicketsBot/GoPanel/database" + dbclient "github.com/TicketsBot/GoPanel/database" "github.com/TicketsBot/GoPanel/rpc/cache" "github.com/TicketsBot/GoPanel/utils" "github.com/TicketsBot/common/permission" syncutils "github.com/TicketsBot/common/utils" + "github.com/TicketsBot/database" "github.com/gin-gonic/gin" - "github.com/jackc/pgx/v4" - "github.com/rxdn/gdl/objects/guild" + "github.com/jackc/pgtype" "github.com/rxdn/gdl/rest/request" "golang.org/x/sync/errgroup" "sort" @@ -24,47 +24,50 @@ type wrappedGuild struct { func GetGuilds(ctx *gin.Context) { userId := ctx.Keys["userid"].(uint64) - guilds, err := database.Client.UserGuilds.Get(userId) + // Get all guilds the user is in + guilds, err := dbclient.Client.UserGuilds.Get(userId) + if err != nil { + ctx.JSON(500, utils.ErrorJson(err)) + return + } + + // Get the subset of guilds that the user is in that the bot is also in + guildIds := make([]uint64, len(guilds)) + guildMap := make(map[uint64]database.UserGuild) // Make a map of all guilds for O(1) access + for i, guild := range guilds { + guildIds[i] = guild.GuildId + guildMap[guild.GuildId] = guild + } + + botGuilds, err := getExistingGuilds(guildIds) if err != nil { ctx.JSON(500, utils.ErrorJson(err)) return } wg := syncutils.NewChannelWaitGroup() - wg.Add(len(guilds)) + wg.Add(len(botGuilds)) group, _ := errgroup.WithContext(context.Background()) ch := make(chan wrappedGuild) - for _, g := range guilds { - g := g + for _, guildId := range botGuilds { + guildId := guildId + g := guildMap[guildId] group.Go(func() error { defer wg.Done() - // verify bot is in guild - if err := cache.Instance.QueryRow(context.Background(), `SELECT 1 from guilds WHERE "guild_id" = $1`, g.GuildId).Scan(nil); err != nil { - if err == pgx.ErrNoRows { - return nil - } else { - return err - } - } - - fakeGuild := guild.Guild{ - Id: g.GuildId, - Owner: g.Owner, - Permissions: g.UserPermissions, - } - + // Determine the user's permission level in this guild + var permLevel permission.PermissionLevel if g.Owner { - fakeGuild.OwnerId = userId - } - - permLevel, err := utils.GetPermissionLevel(g.GuildId, userId) - if err != nil { - // If a Discord error occurs, just skip the server - if _, ok := err.(request.RestError); !ok { - return err + permLevel = permission.Admin + } else { + permLevel, err = utils.GetPermissionLevel(g.GuildId, userId) + if err != nil { + // If a Discord error occurs, just skip the server + if _, ok := err.(request.RestError); !ok { + return err + } } } @@ -96,7 +99,10 @@ func GetGuilds(ctx *gin.Context) { return nil }) - _ = group.Wait() // error not possible + if err := group.Wait(); err != nil { + ctx.JSON(500, utils.ErrorJson(err)) + return + } // sort sort.Slice(adminGuilds, func(i, j int) bool { @@ -106,24 +112,30 @@ func GetGuilds(ctx *gin.Context) { ctx.JSON(200, adminGuilds) } -/*func getAdminGuilds(userId uint64) ([]uint64, error) { - var guilds []uint64 +func getExistingGuilds(userGuilds []uint64) ([]uint64, error) { + query := `SELECT "guild_id" from guilds WHERE "guild_id" = ANY($1);` - // get guilds owned by user - query := `SELECT "guild_id" FROM guilds WHERE "data"->'owner_id' = '$1';` - rows, err := cache.Instance.Query(context.Background(), query, userId) + userGuildsArray := &pgtype.Int8Array{} + if err := userGuildsArray.Set(userGuilds); err != nil { + return nil, err + } + + rows, err := cache.Instance.Query(context.Background(), query, userGuildsArray) if err != nil { return nil, err } + defer rows.Close() + + var existingGuilds []uint64 for rows.Next() { var guildId uint64 if err := rows.Scan(&guildId); err != nil { return nil, err } - guilds = append(guilds, guildId) + existingGuilds = append(existingGuilds, guildId) } - database.Client.Permissions.GetSupport() -}*/ + return existingGuilds, nil +} diff --git a/app/http/endpoints/api/ticket/gettickets.go b/app/http/endpoints/api/ticket/gettickets.go index 6d3a5d6..b866169 100644 --- a/app/http/endpoints/api/ticket/gettickets.go +++ b/app/http/endpoints/api/ticket/gettickets.go @@ -4,6 +4,7 @@ import ( "context" "github.com/TicketsBot/GoPanel/database" "github.com/TicketsBot/GoPanel/rpc/cache" + "github.com/TicketsBot/GoPanel/utils" "github.com/gin-gonic/gin" "github.com/rxdn/gdl/objects/user" "golang.org/x/sync/errgroup" @@ -19,10 +20,7 @@ func GetTickets(ctx *gin.Context) { tickets, err := database.Client.Tickets.GetGuildOpenTickets(guildId) if err != nil { - ctx.AbortWithStatusJSON(500, gin.H{ - "success": false, - "error": err.Error(), - }) + ctx.JSON(500, utils.ErrorJson(err)) return } @@ -50,10 +48,7 @@ func GetTickets(ctx *gin.Context) { } if err := group.Wait(); err != nil { - ctx.AbortWithStatusJSON(500, gin.H{ - "success": false, - "error": err.Error(), - }) + ctx.JSON(500, utils.ErrorJson(err)) return } diff --git a/frontend/src/components/Guild.svelte b/frontend/src/components/Guild.svelte index 5b08bbd..208666d 100644 --- a/frontend/src/components/Guild.svelte +++ b/frontend/src/components/Guild.svelte @@ -38,8 +38,8 @@ } async function goto(guildId) { - const permissionLevels = await getPermissionLevel(guildId); - if (permissionLevels[guildId] === 2) { + const permissionLevel = await getPermissionLevel(guildId); + if (permissionLevel === 2) { window.location.href = `/manage/${guildId}/settings`; } else { window.location.href = `/manage/${guildId}/transcripts`; @@ -47,13 +47,13 @@ } async function getPermissionLevel(guildId) { - const res = await axios.get(`${API_URL}/user/permissionlevel?guilds=${guildId}`); + const res = await axios.get(`${API_URL}/user/permissionlevel?guild=${guildId}`); if (res.status !== 200 || !res.data.success) { notifyError(res.data.error); return; } - return res.data.levels; + return res.data.permission_level; } diff --git a/utils/requestutils.go b/utils/requestutils.go index fa689c2..2a7670c 100644 --- a/utils/requestutils.go +++ b/utils/requestutils.go @@ -5,14 +5,14 @@ import ( "github.com/gin-gonic/gin" ) -func ErrorJson(err error) map[string]interface{} { +func ErrorJson(err error) map[string]any { return ErrorStr(err.Error()) } -func ErrorStr(err string, format ...interface{}) map[string]interface{} { - return gin.H { +func ErrorStr(err string, format ...any) map[string]any { + return gin.H{ "success": false, - "error": fmt.Sprintf(err, format...), + "error": fmt.Sprintf(err, format...), } }