Skip to content

Commit 70f2193

Browse files
committed
Force-merge from main
2 parents 878835b + 76f0d95 commit 70f2193

26 files changed

+1451
-735
lines changed

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -184,3 +184,6 @@ test-scripts/
184184

185185
# Ruff cache
186186
.ruff_cache/
187+
188+
# Old files
189+
old/*

.pre-commit-config.yaml

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
default_language_version:
2-
python: python3.11
2+
python: python3.12
33
files: '.py'
44
exclude: ".env,.yml,.gitignore,.git,.md,.txt"
55
default_stages: [push, commit]
@@ -14,7 +14,7 @@ repos:
1414
additional_dependencies: ["bandit[toml]"]
1515

1616
- repo: https://github.com/psf/black
17-
rev: 24.3.0
17+
rev: 24.4.2
1818
hooks:
1919
- id: black
2020
name: Black
@@ -30,7 +30,7 @@ repos:
3030
stages: [commit]
3131

3232
- repo: https://github.com/astral-sh/ruff-pre-commit
33-
rev: v0.3.5
33+
rev: v0.4.5
3434
hooks:
3535
- id: ruff
3636
name: Ruff

bot/cogs/config.py

+269-6
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,105 @@
11
from __future__ import annotations
22

3-
from typing import TYPE_CHECKING, Annotated, Optional, Union
3+
from typing import (
4+
TYPE_CHECKING,
5+
Annotated,
6+
Any,
7+
NamedTuple,
8+
Optional,
9+
Union,
10+
overload,
11+
)
412

513
import asyncpg
614
import discord
715
import msgspec
816
from async_lru import alru_cache
917
from discord import app_commands
1018
from discord.ext import commands
19+
from libs.tickets.utils import get_cached_thread
1120
from libs.utils import GuildContext
1221
from libs.utils.checks import bot_check_permissions, check_permissions
13-
from libs.utils.embeds import Embed
22+
from libs.utils.embeds import CooldownEmbed, Embed
23+
from libs.utils.pages import SimplePages
1424
from libs.utils.prefix import get_prefix
1525

1626
if TYPE_CHECKING:
27+
from cogs.tickets import Tickets
1728
from rodhaj import Rodhaj
1829

1930
UNKNOWN_ERROR_MESSAGE = (
2031
"An unknown error happened. Please contact the dev team for assistance"
2132
)
2233

2334

35+
class BlocklistTicket(NamedTuple):
36+
cog: Tickets
37+
thread: discord.Thread
38+
39+
40+
class BlocklistEntity(msgspec.Struct, frozen=True):
41+
bot: Rodhaj
42+
guild_id: int
43+
entity_id: int
44+
45+
def format(self) -> str:
46+
user = self.bot.get_user(self.entity_id)
47+
name = user.global_name if user else "Unknown"
48+
return f"{name} (ID: {self.entity_id})"
49+
50+
51+
class BlocklistPages(SimplePages):
52+
def __init__(self, entries: list[BlocklistEntity], *, ctx: GuildContext):
53+
converted = [entry.format() for entry in entries]
54+
super().__init__(converted, ctx=ctx)
55+
56+
57+
class Blocklist:
58+
def __init__(self, bot: Rodhaj):
59+
self.bot = bot
60+
self._blocklist: dict[int, BlocklistEntity] = {}
61+
62+
async def _load(self, connection: Union[asyncpg.Connection, asyncpg.Pool]):
63+
query = """
64+
SELECT guild_id, entity_id
65+
FROM blocklist;
66+
"""
67+
rows = await connection.fetch(query)
68+
return {
69+
row["entity_id"]: BlocklistEntity(bot=self.bot, **dict(row)) for row in rows
70+
}
71+
72+
async def load(self, connection: Optional[asyncpg.Connection] = None):
73+
try:
74+
self._blocklist = await self._load(connection or self.bot.pool)
75+
except Exception:
76+
self._blocklist = {}
77+
78+
@overload
79+
def get(self, key: int) -> Optional[BlocklistEntity]: ...
80+
81+
@overload
82+
def get(self, key: int) -> BlocklistEntity: ...
83+
84+
def get(self, key: int, default: Any = None) -> Optional[BlocklistEntity]:
85+
return self._blocklist.get(key, default)
86+
87+
def __contains__(self, item: int) -> bool:
88+
return item in self._blocklist
89+
90+
def __getitem__(self, item: int) -> BlocklistEntity:
91+
return self._blocklist[item]
92+
93+
def __len__(self) -> int:
94+
return len(self._blocklist)
95+
96+
def all(self) -> dict[int, BlocklistEntity]:
97+
return self._blocklist
98+
99+
def replace(self, blocklist: dict[int, BlocklistEntity]) -> None:
100+
self._blocklist = blocklist
101+
102+
24103
# Msgspec Structs are usually extremely fast compared to slotted classes
25104
class GuildConfig(msgspec.Struct):
26105
bot: Rodhaj
@@ -30,7 +109,6 @@ class GuildConfig(msgspec.Struct):
30109
logging_channel_id: int
31110
logging_broadcast_url: str
32111
ticket_broadcast_url: str
33-
locked: bool = False
34112

35113
@property
36114
def category_channel(self) -> Optional[discord.CategoryChannel]:
@@ -74,7 +152,7 @@ async def get_ticket_webhook(self) -> Optional[discord.Webhook]:
74152
@alru_cache()
75153
async def get_config(self) -> Optional[GuildConfig]:
76154
query = """
77-
SELECT id, category_id, ticket_channel_id, logging_channel_id, logging_broadcast_url, ticket_broadcast_url, locked
155+
SELECT id, category_id, ticket_channel_id, logging_channel_id, logging_broadcast_url, ticket_broadcast_url
78156
FROM guild_config
79157
WHERE id = $1;
80158
"""
@@ -100,8 +178,8 @@ class SetupFlags(commands.FlagConverter):
100178

101179

102180
class PrefixConverter(commands.Converter):
103-
async def convert(self, ctx: commands.Context, argument: str):
104-
user_id = ctx.bot.user.id
181+
async def convert(self, ctx: GuildContext, argument: str):
182+
user_id = ctx.bot.user.id # type: ignore # Already logged in by this time
105183
if argument.startswith((f"<@{user_id}>", f"<@!{user_id}>")):
106184
raise commands.BadArgument("That is a reserved prefix already in use.")
107185
if len(argument) > 100:
@@ -139,6 +217,39 @@ def clean_prefixes(self, prefixes: Union[str, list[str]]) -> str:
139217

140218
return ", ".join(f"`{prefix}`" for prefix in prefixes[2:])
141219

220+
### Blocklist Utilities
221+
222+
async def can_be_blocked(self, ctx: GuildContext, entity: discord.Member) -> bool:
223+
if entity.id == ctx.author.id or await self.bot.is_owner(entity) or entity.bot:
224+
return False
225+
226+
# Hierarchy check
227+
if (
228+
isinstance(ctx.author, discord.Member)
229+
and entity.top_role > ctx.author.top_role
230+
):
231+
return False
232+
233+
return True
234+
235+
async def get_block_ticket(
236+
self, entity: discord.Member
237+
) -> Optional[BlocklistTicket]:
238+
tickets_cog: Optional[Tickets] = self.bot.get_cog("Tickets") # type: ignore
239+
cached_ticket = await get_cached_thread(self.bot, entity.id)
240+
if not tickets_cog or not cached_ticket:
241+
return
242+
243+
return BlocklistTicket(cog=tickets_cog, thread=cached_ticket.thread)
244+
245+
### Misc Utilities
246+
async def _handle_error(
247+
self, ctx: GuildContext, error: commands.CommandError
248+
) -> None:
249+
if isinstance(error, commands.CommandOnCooldown):
250+
embed = CooldownEmbed(error.retry_after)
251+
await ctx.send(embed=embed)
252+
142253
@check_permissions(manage_guild=True)
143254
@bot_check_permissions(manage_channels=True, manage_webhooks=True)
144255
@commands.guild_only()
@@ -236,6 +347,13 @@ async def setup(self, ctx: GuildContext, *, flags: SetupFlags) -> None:
236347
), # U+2705 White Heavy Check Mark
237348
moderated=True,
238349
),
350+
discord.ForumTag(
351+
name="Locked",
352+
emoji=discord.PartialEmoji(
353+
name="\U0001f510"
354+
), # U+1F510 CLOSED LOCK WITH KEY
355+
moderated=True,
356+
),
239357
]
240358

241359
delete_reason = "Failed to create channel due to existing config"
@@ -353,6 +471,18 @@ async def delete(self, ctx: GuildContext) -> None:
353471
else:
354472
await ctx.send("Cancelling.")
355473

474+
@setup.error
475+
async def on_setup_error(
476+
self, ctx: GuildContext, error: commands.CommandError
477+
) -> None:
478+
await self._handle_error(ctx, error)
479+
480+
@delete.error
481+
async def on_delete_error(
482+
self, ctx: GuildContext, error: commands.CommandError
483+
) -> None:
484+
await self._handle_error(ctx, error)
485+
356486
@check_permissions(manage_guild=True)
357487
@commands.guild_only()
358488
@config.group(name="prefix", fallback="info")
@@ -446,6 +576,139 @@ async def prefix_delete(
446576
else:
447577
await ctx.send("Confirmation cancelled. Please try again")
448578

579+
# In order to prevent abuse, 4 checks must be performed:
580+
# 1. Permissions check
581+
# 2. Is the selected entity higher than the author's current hierarchy? (in terms of role and members)
582+
# 3. Is the bot itself the entity getting blocklisted?
583+
# 4. Is the author themselves trying to get blocklisted?
584+
# This system must be addressed with care as it is extremely dangerous
585+
# TODO: Add an history command to view past history of entity
586+
@check_permissions(manage_messages=True, manage_roles=True, moderate_members=True)
587+
@commands.guild_only()
588+
@commands.hybrid_group(name="blocklist", fallback="info")
589+
async def blocklist(self, ctx: GuildContext) -> None:
590+
"""Manages and views the current blocklist"""
591+
blocklist = self.bot.blocklist.all()
592+
pages = BlocklistPages([entry for entry in blocklist.values()], ctx=ctx)
593+
await pages.start()
594+
595+
@check_permissions(manage_messages=True, manage_roles=True, moderate_members=True)
596+
@blocklist.command(name="add")
597+
@app_commands.describe(
598+
entity="The member to add to the blocklist",
599+
)
600+
async def blocklist_add(
601+
self,
602+
ctx: GuildContext,
603+
entity: discord.Member,
604+
) -> None:
605+
"""Adds an member into the blocklist"""
606+
if not await self.can_be_blocked(ctx, entity):
607+
await ctx.send("Failed to block entity")
608+
return
609+
610+
block_ticket = await self.get_block_ticket(entity)
611+
if not block_ticket:
612+
await ctx.send(
613+
"Unable to obtain block ticket. Perhaps the user doesn't have an active ticket?"
614+
)
615+
return
616+
617+
blocklist = self.bot.blocklist.all().copy()
618+
blocklist[entity.id] = BlocklistEntity(
619+
bot=self.bot, guild_id=ctx.guild.id, entity_id=entity.id
620+
)
621+
query = """
622+
WITH blocklist_insert AS (
623+
INSERT INTO blocklist (guild_id, entity_id)
624+
VALUES ($1, $2)
625+
RETURNING entity_id
626+
)
627+
UPDATE tickets
628+
SET locked = true
629+
WHERE owner_id = (SELECT entity_id FROM blocklist_insert);
630+
"""
631+
lock_reason = f"{entity.global_name} is blocked from using Rodhaj"
632+
async with self.bot.pool.acquire() as connection:
633+
tr = connection.transaction()
634+
await tr.start()
635+
try:
636+
await connection.execute(query, ctx.guild.id, entity.id)
637+
except asyncpg.UniqueViolationError:
638+
del blocklist[entity.id]
639+
await tr.rollback()
640+
await ctx.send("User is already in the blocklist")
641+
except Exception:
642+
del blocklist[entity.id]
643+
await tr.rollback()
644+
await ctx.send("Unable to block user")
645+
else:
646+
self.bot.metrics.features.blocked_users.inc()
647+
await tr.commit()
648+
self.bot.blocklist.replace(blocklist)
649+
650+
await block_ticket.cog.soft_lock_ticket(
651+
block_ticket.thread, lock_reason
652+
)
653+
await ctx.send(f"{entity.mention} has been blocked")
654+
655+
@check_permissions(manage_messages=True, manage_roles=True, moderate_members=True)
656+
@blocklist.command(name="remove")
657+
@app_commands.describe(entity="The member to remove from the blocklist")
658+
async def blocklist_remove(self, ctx: GuildContext, entity: discord.Member) -> None:
659+
"""Removes an member from the blocklist"""
660+
if not await self.can_be_blocked(ctx, entity):
661+
await ctx.send("Failed to unblock entity")
662+
return
663+
664+
block_ticket = await self.get_block_ticket(entity)
665+
if not block_ticket:
666+
# Must mean that they must have a thread cached
667+
await ctx.send("Unable to obtain block ticket.")
668+
return
669+
670+
blocklist = self.bot.blocklist.all().copy()
671+
try:
672+
del blocklist[entity.id]
673+
except KeyError:
674+
await ctx.send(
675+
"Unable to unblock user. Perhaps is the user not blocked yet?"
676+
)
677+
return
678+
679+
# As the first line catches the errors
680+
# when we delete an result in our cache,
681+
# it doesn't really matter whether it's deleted or not actually.
682+
# it would return the same thing - DELETE 0
683+
# Note: An timer would have to delete this technically
684+
query = """
685+
WITH blocklist_delete AS (
686+
DELETE FROM blocklist
687+
WHERE entity_id = $1
688+
RETURNING entity_id
689+
)
690+
UPDATE tickets
691+
SET locked = false
692+
WHERE owner_id = (SELECT entity_id FROM blocklist_delete);
693+
"""
694+
unlock_reason = f"{entity.global_name} is unblocked from using Rodhaj"
695+
async with self.bot.pool.acquire() as connection:
696+
tr = connection.transaction()
697+
await tr.start()
698+
try:
699+
await connection.execute(query, entity.id)
700+
except Exception:
701+
await tr.rollback()
702+
await ctx.send("Unable to block user")
703+
else:
704+
self.bot.metrics.features.blocked_users.dec()
705+
await tr.commit()
706+
self.bot.blocklist.replace(blocklist)
707+
await block_ticket.cog.soft_unlock_ticket(
708+
block_ticket.thread, unlock_reason
709+
)
710+
await ctx.send(f"{entity.mention} has been unblocked")
711+
449712

450713
async def setup(bot: Rodhaj) -> None:
451714
await bot.add_cog(Config(bot))

0 commit comments

Comments
 (0)