1
1
from __future__ import annotations
2
2
3
- from typing import TYPE_CHECKING , Annotated , Optional , Union
3
+ from typing import (
4
+ TYPE_CHECKING ,
5
+ Annotated ,
6
+ Any ,
7
+ NamedTuple ,
8
+ Optional ,
9
+ Union ,
10
+ overload ,
11
+ )
4
12
5
13
import asyncpg
6
14
import discord
7
15
import msgspec
8
16
from async_lru import alru_cache
9
17
from discord import app_commands
10
18
from discord .ext import commands
19
+ from libs .tickets .utils import get_cached_thread
11
20
from libs .utils import GuildContext
12
21
from libs .utils .checks import bot_check_permissions , check_permissions
13
22
from libs .utils .embeds import Embed
23
+ from libs .utils .pages import SimplePages
14
24
from libs .utils .prefix import get_prefix
15
25
16
26
if TYPE_CHECKING :
27
+ from cogs .tickets import Tickets
17
28
from rodhaj import Rodhaj
18
29
19
30
UNKNOWN_ERROR_MESSAGE = (
20
31
"An unknown error happened. Please contact the dev team for assistance"
21
32
)
22
33
23
34
35
+ class BlocklistTicket (NamedTuple ):
36
+ cog : Tickets
37
+ thread : discord .Thread
38
+
39
+
40
+ class BlocklistEntity (msgspec .Struct , frozen = True ):
41
+ bot : Rodhaj
42
+ guild_id : int
43
+ entity_id : int
44
+
45
+ def format (self ) -> str :
46
+ user = self .bot .get_user (self .entity_id )
47
+ name = user .global_name if user else "Unknown"
48
+ return f"{ name } (ID: { self .entity_id } )"
49
+
50
+
51
+ class BlocklistPages (SimplePages ):
52
+ def __init__ (self , entries : list [BlocklistEntity ], * , ctx : GuildContext ):
53
+ converted = [entry .format () for entry in entries ]
54
+ super ().__init__ (converted , ctx = ctx )
55
+
56
+
57
+ class Blocklist :
58
+ def __init__ (self , bot : Rodhaj ):
59
+ self .bot = bot
60
+ self ._blocklist : dict [int , BlocklistEntity ] = {}
61
+
62
+ async def _load (self , connection : Union [asyncpg .Connection , asyncpg .Pool ]):
63
+ query = """
64
+ SELECT guild_id, entity_id
65
+ FROM blocklist;
66
+ """
67
+ rows = await connection .fetch (query )
68
+ return {
69
+ row ["entity_id" ]: BlocklistEntity (bot = self .bot , ** dict (row )) for row in rows
70
+ }
71
+
72
+ async def load (self , connection : Optional [asyncpg .Connection ] = None ):
73
+ try :
74
+ self ._blocklist = await self ._load (connection or self .bot .pool )
75
+ except Exception :
76
+ self ._blocklist = {}
77
+
78
+ @overload
79
+ def get (self , key : int ) -> Optional [BlocklistEntity ]: ...
80
+
81
+ @overload
82
+ def get (self , key : int ) -> BlocklistEntity : ...
83
+
84
+ def get (self , key : int , default : Any = None ) -> Optional [BlocklistEntity ]:
85
+ return self ._blocklist .get (key , default )
86
+
87
+ def __contains__ (self , item : int ) -> bool :
88
+ return item in self ._blocklist
89
+
90
+ def __getitem__ (self , item : int ) -> BlocklistEntity :
91
+ return self ._blocklist [item ]
92
+
93
+ def __len__ (self ) -> int :
94
+ return len (self ._blocklist )
95
+
96
+ def all (self ) -> dict [int , BlocklistEntity ]:
97
+ return self ._blocklist
98
+
99
+ def replace (self , blocklist : dict [int , BlocklistEntity ]) -> None :
100
+ self ._blocklist = blocklist
101
+
102
+
24
103
# Msgspec Structs are usually extremely fast compared to slotted classes
25
104
class GuildConfig (msgspec .Struct ):
26
105
bot : Rodhaj
@@ -30,7 +109,6 @@ class GuildConfig(msgspec.Struct):
30
109
logging_channel_id : int
31
110
logging_broadcast_url : str
32
111
ticket_broadcast_url : str
33
- locked : bool = False
34
112
35
113
@property
36
114
def category_channel (self ) -> Optional [discord .CategoryChannel ]:
@@ -74,7 +152,7 @@ async def get_ticket_webhook(self) -> Optional[discord.Webhook]:
74
152
@alru_cache ()
75
153
async def get_config (self ) -> Optional [GuildConfig ]:
76
154
query = """
77
- SELECT id, category_id, ticket_channel_id, logging_channel_id, logging_broadcast_url, ticket_broadcast_url, locked
155
+ SELECT id, category_id, ticket_channel_id, logging_channel_id, logging_broadcast_url, ticket_broadcast_url
78
156
FROM guild_config
79
157
WHERE id = $1;
80
158
"""
@@ -100,8 +178,8 @@ class SetupFlags(commands.FlagConverter):
100
178
101
179
102
180
class PrefixConverter (commands .Converter ):
103
- async def convert (self , ctx : commands . Context , argument : str ):
104
- user_id = ctx .bot .user .id
181
+ async def convert (self , ctx : GuildContext , argument : str ):
182
+ user_id = ctx .bot .user .id # type: ignore # Already logged in by this time
105
183
if argument .startswith ((f"<@{ user_id } >" , f"<@!{ user_id } >" )):
106
184
raise commands .BadArgument ("That is a reserved prefix already in use." )
107
185
if len (argument ) > 100 :
@@ -139,6 +217,31 @@ def clean_prefixes(self, prefixes: Union[str, list[str]]) -> str:
139
217
140
218
return ", " .join (f"`{ prefix } `" for prefix in prefixes [2 :])
141
219
220
+ ### Blocklist Utilities
221
+
222
+ async def can_be_blocked (self , ctx : GuildContext , entity : discord .Member ) -> bool :
223
+ if entity .id == ctx .author .id or await self .bot .is_owner (entity ) or entity .bot :
224
+ return False
225
+
226
+ # Hierarchy check
227
+ if (
228
+ isinstance (ctx .author , discord .Member )
229
+ and entity .top_role > ctx .author .top_role
230
+ ):
231
+ return False
232
+
233
+ return True
234
+
235
+ async def get_block_ticket (
236
+ self , entity : discord .Member
237
+ ) -> Optional [BlocklistTicket ]:
238
+ tickets_cog : Optional [Tickets ] = self .bot .get_cog ("Tickets" ) # type: ignore
239
+ cached_ticket = await get_cached_thread (self .bot , entity .id )
240
+ if not tickets_cog or not cached_ticket :
241
+ return
242
+
243
+ return BlocklistTicket (cog = tickets_cog , thread = cached_ticket .thread )
244
+
142
245
@check_permissions (manage_guild = True )
143
246
@bot_check_permissions (manage_channels = True , manage_webhooks = True )
144
247
@commands .guild_only ()
@@ -236,6 +339,13 @@ async def setup(self, ctx: GuildContext, *, flags: SetupFlags) -> None:
236
339
), # U+2705 White Heavy Check Mark
237
340
moderated = True ,
238
341
),
342
+ discord .ForumTag (
343
+ name = "Locked" ,
344
+ emoji = discord .PartialEmoji (
345
+ name = "\U0001f510 "
346
+ ), # U+1F510 CLOSED LOCK WITH KEY
347
+ moderated = True ,
348
+ ),
239
349
]
240
350
241
351
delete_reason = "Failed to create channel due to existing config"
@@ -446,6 +556,137 @@ async def prefix_delete(
446
556
else :
447
557
await ctx .send ("Confirmation cancelled. Please try again" )
448
558
559
+ # In order to prevent abuse, 4 checks must be performed:
560
+ # 1. Permissions check
561
+ # 2. Is the selected entity higher than the author's current hierarchy? (in terms of role and members)
562
+ # 3. Is the bot itself the entity getting blocklisted?
563
+ # 4. Is the author themselves trying to get blocklisted?
564
+ # This system must be addressed with care as it is extremely dangerous
565
+ # TODO: Add an history command to view past history of entity
566
+ @check_permissions (manage_messages = True , manage_roles = True , moderate_members = True )
567
+ @commands .guild_only ()
568
+ @commands .hybrid_group (name = "blocklist" , fallback = "info" )
569
+ async def blocklist (self , ctx : GuildContext ) -> None :
570
+ """Manages and views the current blocklist"""
571
+ blocklist = self .bot .blocklist .all ()
572
+ pages = BlocklistPages ([entry for entry in blocklist .values ()], ctx = ctx )
573
+ await pages .start ()
574
+
575
+ @check_permissions (manage_messages = True , manage_roles = True , moderate_members = True )
576
+ @blocklist .command (name = "add" )
577
+ @app_commands .describe (
578
+ entity = "The member to add to the blocklist" ,
579
+ )
580
+ async def blocklist_add (
581
+ self ,
582
+ ctx : GuildContext ,
583
+ entity : discord .Member ,
584
+ ) -> None :
585
+ """Adds an member into the blocklist"""
586
+ if not await self .can_be_blocked (ctx , entity ):
587
+ await ctx .send ("Failed to block entity" )
588
+ return
589
+
590
+ block_ticket = await self .get_block_ticket (entity )
591
+ if not block_ticket :
592
+ await ctx .send (
593
+ "Unable to obtain block ticket. Perhaps the user doesn't have an active ticket?"
594
+ )
595
+ return
596
+
597
+ blocklist = self .bot .blocklist .all ().copy ()
598
+ blocklist [entity .id ] = BlocklistEntity (
599
+ bot = self .bot , guild_id = ctx .guild .id , entity_id = entity .id
600
+ )
601
+ query = """
602
+ WITH blocklist_insert AS (
603
+ INSERT INTO blocklist (guild_id, entity_id)
604
+ VALUES ($1, $2)
605
+ RETURNING entity_id
606
+ )
607
+ UPDATE tickets
608
+ SET locked = true
609
+ WHERE owner_id = (SELECT entity_id FROM blocklist_insert);
610
+ """
611
+ lock_reason = f"{ entity .global_name } is blocked from using Rodhaj"
612
+ async with self .bot .pool .acquire () as connection :
613
+ tr = connection .transaction ()
614
+ await tr .start ()
615
+ try :
616
+ await connection .execute (query , ctx .guild .id , entity .id )
617
+ except asyncpg .UniqueViolationError :
618
+ del blocklist [entity .id ]
619
+ await tr .rollback ()
620
+ await ctx .send ("User is already in the blocklist" )
621
+ except Exception :
622
+ del blocklist [entity .id ]
623
+ await tr .rollback ()
624
+ await ctx .send ("Unable to block user" )
625
+ else :
626
+ await tr .commit ()
627
+ self .bot .blocklist .replace (blocklist )
628
+
629
+ await block_ticket .cog .soft_lock_ticket (
630
+ block_ticket .thread , lock_reason
631
+ )
632
+ await ctx .send (f"{ entity .mention } has been blocked" )
633
+
634
+ @check_permissions (manage_messages = True , manage_roles = True , moderate_members = True )
635
+ @blocklist .command (name = "remove" )
636
+ @app_commands .describe (entity = "The member to remove from the blocklist" )
637
+ async def blocklist_remove (self , ctx : GuildContext , entity : discord .Member ) -> None :
638
+ """Removes an member from the blocklist"""
639
+ if not await self .can_be_blocked (ctx , entity ):
640
+ await ctx .send ("Failed to unblock entity" )
641
+ return
642
+
643
+ block_ticket = await self .get_block_ticket (entity )
644
+ if not block_ticket :
645
+ # Must mean that they must have a thread cached
646
+ await ctx .send ("Unable to obtain block ticket." )
647
+ return
648
+
649
+ blocklist = self .bot .blocklist .all ().copy ()
650
+ try :
651
+ del blocklist [entity .id ]
652
+ except KeyError :
653
+ await ctx .send (
654
+ "Unable to unblock user. Perhaps is the user not blocked yet?"
655
+ )
656
+ return
657
+
658
+ # As the first line catches the errors
659
+ # when we delete an result in our cache,
660
+ # it doesn't really matter whether it's deleted or not actually.
661
+ # it would return the same thing - DELETE 0
662
+ # Note: An timer would have to delete this technically
663
+ query = """
664
+ WITH blocklist_delete AS (
665
+ DELETE FROM blocklist
666
+ WHERE entity_id = $1
667
+ RETURNING entity_id
668
+ )
669
+ UPDATE tickets
670
+ SET locked = false
671
+ WHERE owner_id = (SELECT entity_id FROM blocklist_delete);
672
+ """
673
+ unlock_reason = f"{ entity .global_name } is unblocked from using Rodhaj"
674
+ async with self .bot .pool .acquire () as connection :
675
+ tr = connection .transaction ()
676
+ await tr .start ()
677
+ try :
678
+ await connection .execute (query , entity .id )
679
+ except Exception :
680
+ await tr .rollback ()
681
+ await ctx .send ("Unable to block user" )
682
+ else :
683
+ await tr .commit ()
684
+ self .bot .blocklist .replace (blocklist )
685
+ await block_ticket .cog .soft_unlock_ticket (
686
+ block_ticket .thread , unlock_reason
687
+ )
688
+ await ctx .send (f"{ entity .mention } has been unblocked" )
689
+
449
690
450
691
async def setup (bot : Rodhaj ) -> None :
451
692
await bot .add_cog (Config (bot ))
0 commit comments