8
8
from textwrap import dedent
9
9
from typing import Literal , NamedTuple , TYPE_CHECKING , get_args
10
10
11
- from discord import AllowedMentions , HTTPException , Interaction , Message , NotFound , Reaction , User , enums , ui
11
+ from discord import (
12
+ AllowedMentions ,
13
+ HTTPException ,
14
+ Interaction ,
15
+ Message ,
16
+ NotFound ,
17
+ Reaction ,
18
+ User ,
19
+ enums ,
20
+ ui ,
21
+ )
12
22
from discord .ext .commands import Cog , Command , Context , Converter , command , guild_only
23
+ from discord .ui import Button
13
24
from pydis_core .utils import interactions , paste_service
14
25
from pydis_core .utils .paste_service import PasteFile , send_to_paste_service
15
26
from pydis_core .utils .regex import FORMATTED_CODE_REGEX , RAW_CODE_REGEX
@@ -82,13 +93,21 @@ def print_last_line():
82
93
# The Snekbox commands' whitelists and blacklists.
83
94
NO_SNEKBOX_CHANNELS = (Channels .python_general ,)
84
95
NO_SNEKBOX_CATEGORIES = ()
85
- SNEKBOX_ROLES = (Roles .helpers , Roles .moderators , Roles .admins , Roles .owners , Roles .python_community , Roles .partners )
96
+ SNEKBOX_ROLES = (
97
+ Roles .helpers ,
98
+ Roles .moderators ,
99
+ Roles .admins ,
100
+ Roles .owners ,
101
+ Roles .python_community ,
102
+ Roles .partners ,
103
+ )
86
104
87
105
REDO_EMOJI = "\U0001f501 " # :repeat:
88
106
REDO_TIMEOUT = 30
89
107
90
108
SupportedPythonVersions = Literal ["3.12" , "3.13" , "3.13t" ]
91
109
110
+
92
111
class FilteredFiles (NamedTuple ):
93
112
allowed : list [FileAttachment ]
94
113
blocked : list [FileAttachment ]
@@ -119,7 +138,9 @@ async def convert(cls, ctx: Context, code: str) -> list[str]:
119
138
code , block , lang , delim = match .group ("code" , "block" , "lang" , "delim" )
120
139
codeblocks = [dedent (code )]
121
140
if block :
122
- info = (f"'{ lang } ' highlighted" if lang else "plain" ) + " code block"
141
+ info = (
142
+ f"'{ lang } ' highlighted" if lang else "plain"
143
+ ) + " code block"
123
144
else :
124
145
info = f"{ delim } -enclosed inline code"
125
146
else :
@@ -142,7 +163,9 @@ def __init__(
142
163
job : EvalJob ,
143
164
) -> None :
144
165
self .version_to_run = version_to_run
145
- super ().__init__ (label = f"Run in { self .version_to_run } " , style = enums .ButtonStyle .primary )
166
+ super ().__init__ (
167
+ label = f"Run in { self .version_to_run } " , style = enums .ButtonStyle .primary
168
+ )
146
169
147
170
self .snekbox_cog = snekbox_cog
148
171
self .ctx = ctx
@@ -163,7 +186,9 @@ async def callback(self, interaction: Interaction) -> None:
163
186
# The log arg on send_job will stop the actual job from running.
164
187
await interaction .message .delete ()
165
188
166
- await self .snekbox_cog .run_job (self .ctx , self .job .as_version (self .version_to_run ))
189
+ await self .snekbox_cog .run_job (
190
+ self .ctx , self .job .as_version (self .version_to_run )
191
+ )
167
192
168
193
169
194
class Snekbox (Cog ):
@@ -197,7 +222,9 @@ async def post_job(self, job: EvalJob) -> EvalResult:
197
222
"""Send a POST request to the Snekbox API to evaluate code and return the results."""
198
223
data = job .to_dict ()
199
224
200
- async with self .bot .http_session .post (URLs .snekbox_eval_api , json = data , raise_for_status = True ) as resp :
225
+ async with self .bot .http_session .post (
226
+ URLs .snekbox_eval_api , json = data , raise_for_status = True
227
+ ) as resp :
201
228
return EvalResult .from_dict (await resp .json ())
202
229
203
230
async def upload_output (self , output : str ) -> str | None :
@@ -257,7 +284,10 @@ async def format_output(
257
284
258
285
if ESCAPE_REGEX .findall (output ):
259
286
paste_link = await self .upload_output (original_output )
260
- return "Code block escape attempt detected; will not output result" , paste_link
287
+ return (
288
+ "Code block escape attempt detected; will not output result" ,
289
+ paste_link ,
290
+ )
261
291
262
292
truncated = False
263
293
lines = output .splitlines ()
@@ -269,12 +299,14 @@ async def format_output(
269
299
if len (lines ) > max_lines :
270
300
truncated = True
271
301
if len (lines ) == max_lines + 1 :
272
- lines = lines [:max_lines - 1 ]
302
+ lines = lines [: max_lines - 1 ]
273
303
else :
274
304
lines = lines [:max_lines ]
275
305
output = "\n " .join (lines )
276
306
if len (output ) >= max_chars :
277
- output = f"{ output [:max_chars ]} \n ... (truncated - too long, too many lines)"
307
+ output = (
308
+ f"{ output [:max_chars ]} \n ... (truncated - too long, too many lines)"
309
+ )
278
310
else :
279
311
output = f"{ output } \n ... (truncated - too many lines)"
280
312
elif len (output ) >= max_chars :
@@ -292,7 +324,9 @@ async def format_output(
292
324
293
325
return output , paste_link
294
326
295
- async def format_file_text (self , text_files : list [FileAttachment ], output : str ) -> str :
327
+ async def format_file_text (
328
+ self , text_files : list [FileAttachment ], output : str
329
+ ) -> str :
296
330
# Inline until budget, then upload to paste service
297
331
# Budget is shared with stdout, so subtract what we've already used
298
332
budget_lines = MAX_OUTPUT_BLOCK_LINES - (output .count ("\n " ) + 1 )
@@ -311,7 +345,7 @@ async def format_file_text(self, text_files: list[FileAttachment], output: str)
311
345
budget_lines ,
312
346
budget_chars ,
313
347
line_nums = False ,
314
- output_default = "[Empty]"
348
+ output_default = "[Empty]" ,
315
349
)
316
350
# With any link, use it (don't use budget)
317
351
if link_text :
@@ -325,24 +359,30 @@ async def format_file_text(self, text_files: list[FileAttachment], output: str)
325
359
326
360
def format_blocked_extensions (self , blocked : list [FileAttachment ]) -> str :
327
361
# Sort by length and then lexicographically to fit as many as possible before truncating.
328
- blocked_sorted = sorted (set (f .suffix for f in blocked ), key = lambda e : (len (e ), e ))
362
+ blocked_sorted = sorted (
363
+ set (f .suffix for f in blocked ), key = lambda e : (len (e ), e )
364
+ )
329
365
330
366
# Only no extension
331
367
if len (blocked_sorted ) == 1 and blocked_sorted [0 ] == "" :
332
368
blocked_msg = "Files with no extension can't be uploaded."
333
369
# Both
334
370
elif "" in blocked_sorted :
335
- blocked_str = self .join_blocked_extensions (ext for ext in blocked_sorted if ext )
336
- blocked_msg = (
337
- f"Files with no extension or disallowed extensions can't be uploaded: **{ blocked_str } **"
371
+ blocked_str = self .join_blocked_extensions (
372
+ ext for ext in blocked_sorted if ext
338
373
)
374
+ blocked_msg = f"Files with no extension or disallowed extensions can't be uploaded: **{ blocked_str } **"
339
375
else :
340
376
blocked_str = self .join_blocked_extensions (blocked_sorted )
341
- blocked_msg = f"Files with disallowed extensions can't be uploaded: **{ blocked_str } **"
377
+ blocked_msg = (
378
+ f"Files with disallowed extensions can't be uploaded: **{ blocked_str } **"
379
+ )
342
380
343
381
return f"\n { Emojis .failed_file } { blocked_msg } "
344
382
345
- def join_blocked_extensions (self , extensions : Iterable [str ], delimiter : str = ", " , char_limit : int = 100 ) -> str :
383
+ def join_blocked_extensions (
384
+ self , extensions : Iterable [str ], delimiter : str = ", " , char_limit : int = 100
385
+ ) -> str :
346
386
joined = ""
347
387
for ext in extensions :
348
388
cur_delimiter = delimiter if joined else ""
@@ -354,8 +394,9 @@ def join_blocked_extensions(self, extensions: Iterable[str], delimiter: str = ",
354
394
355
395
return joined
356
396
357
-
358
- def _filter_files (self , ctx : Context , files : list [FileAttachment ], blocked_exts : set [str ]) -> FilteredFiles :
397
+ def _filter_files (
398
+ self , ctx : Context , files : list [FileAttachment ], blocked_exts : set [str ]
399
+ ) -> FilteredFiles :
359
400
"""Filter to restrict files to allowed extensions. Return a named tuple of allowed and blocked files lists."""
360
401
# Filter files into allowed and blocked
361
402
blocked = []
@@ -370,7 +411,7 @@ def _filter_files(self, ctx: Context, files: list[FileAttachment], blocked_exts:
370
411
blocked_str = ", " .join (f .suffix for f in blocked )
371
412
log .info (
372
413
f"User '{ ctx .author } ' ({ ctx .author .id } ) uploaded blacklisted file(s) in eval: { blocked_str } " ,
373
- extra = {"attachment_list" : [f .filename for f in files ]}
414
+ extra = {"attachment_list" : [f .filename for f in files ]},
374
415
)
375
416
376
417
return FilteredFiles (allowed , blocked )
@@ -399,16 +440,18 @@ async def send_job(self, ctx: Context, job: EvalJob) -> Message:
399
440
400
441
# This is done to make sure the last line of output contains the error
401
442
# and the error is not manually printed by the author with a syntax error.
402
- if result .stdout .rstrip ().endswith ("EOFError: EOF when reading a line" ) and result .returncode == 1 :
403
- msg += "\n :warning: Note: `input` is not supported by the bot :warning:\n "
443
+ if (
444
+ result .stdout .rstrip ().endswith ("EOFError: EOF when reading a line" )
445
+ and result .returncode == 1
446
+ ):
447
+ msg += (
448
+ "\n :warning: Note: `input` is not supported by the bot :warning:\n "
449
+ )
404
450
405
451
# Skip output if it's empty and there are file uploads
406
452
if result .stdout or not result .has_files :
407
453
msg += f"\n ```ansi\n { output } \n ```"
408
454
409
- if paste_link :
410
- msg += f"\n Full output: { paste_link } "
411
-
412
455
# Additional files error message after output
413
456
if files_error := result .files_error_message :
414
457
msg += f"\n { files_error } "
@@ -423,9 +466,13 @@ async def send_job(self, ctx: Context, job: EvalJob) -> Message:
423
466
failed_files = [FileAttachment (name , b"" ) for name in result .failed_files ]
424
467
total_files = result .files + failed_files
425
468
if filter_cog :
426
- block_output , blocked_exts = await filter_cog .filter_snekbox_output (msg , total_files , ctx .message )
469
+ block_output , blocked_exts = await filter_cog .filter_snekbox_output (
470
+ msg , total_files , ctx .message
471
+ )
427
472
if block_output :
428
- return await ctx .send ("Attempt to circumvent filter detected. Moderator team has been alerted." )
473
+ return await ctx .send (
474
+ "Attempt to circumvent filter detected. Moderator team has been alerted."
475
+ )
429
476
430
477
# Filter file extensions
431
478
allowed , blocked = self ._filter_files (ctx , result .files , blocked_exts )
@@ -435,8 +482,18 @@ async def send_job(self, ctx: Context, job: EvalJob) -> Message:
435
482
436
483
# Upload remaining non-text files
437
484
files = [f .to_file () for f in allowed if f not in text_files ]
438
- allowed_mentions = AllowedMentions (everyone = False , roles = False , users = [ctx .author ])
485
+ allowed_mentions = AllowedMentions (
486
+ everyone = False , roles = False , users = [ctx .author ]
487
+ )
439
488
view = self .build_python_version_switcher_view (job .version , ctx , job )
489
+ if paste_link :
490
+ # Create a button
491
+ button = Button (
492
+ label = "View Full Output" , # Button text
493
+ url = paste_link , # The URL the button links to
494
+ )
495
+
496
+ view .add_item (button )
440
497
441
498
if ctx .message .channel == ctx .channel :
442
499
# Don't fail if the command invoking message was deleted.
@@ -446,15 +503,19 @@ async def send_job(self, ctx: Context, job: EvalJob) -> Message:
446
503
allowed_mentions = allowed_mentions ,
447
504
view = view ,
448
505
files = files ,
449
- reference = message
506
+ reference = message ,
450
507
)
451
508
else :
452
509
# The command was redirected so a reply wont work, send a normal message with a mention.
453
510
msg = f"{ ctx .author .mention } { msg } "
454
- response = await ctx .send (msg , allowed_mentions = allowed_mentions , view = view , files = files )
511
+ response = await ctx .send (
512
+ msg , allowed_mentions = allowed_mentions , view = view , files = files
513
+ )
455
514
view .message = response
456
515
457
- log .info (f"{ ctx .author } 's { job .name } job had a return code of { result .returncode } " )
516
+ log .info (
517
+ f"{ ctx .author } 's { job .name } job had a return code of { result .returncode } "
518
+ )
458
519
return response
459
520
460
521
async def continue_job (
@@ -472,15 +533,11 @@ async def continue_job(
472
533
with contextlib .suppress (NotFound ):
473
534
try :
474
535
_ , new_message = await self .bot .wait_for (
475
- "message_edit" ,
476
- check = _predicate_message_edit ,
477
- timeout = REDO_TIMEOUT
536
+ "message_edit" , check = _predicate_message_edit , timeout = REDO_TIMEOUT
478
537
)
479
538
await ctx .message .add_reaction (REDO_EMOJI )
480
539
await self .bot .wait_for (
481
- "reaction_add" ,
482
- check = _predicate_emoji_reaction ,
483
- timeout = 10
540
+ "reaction_add" , check = _predicate_emoji_reaction , timeout = 10
484
541
)
485
542
486
543
# Ensure the response that's about to be edited is still the most recent.
@@ -576,14 +633,14 @@ async def run_job(
576
633
bypass_roles = SNEKBOX_ROLES ,
577
634
categories = NO_SNEKBOX_CATEGORIES ,
578
635
channels = NO_SNEKBOX_CHANNELS ,
579
- ping_user = False
636
+ ping_user = False ,
580
637
)
581
638
async def eval_command (
582
639
self ,
583
640
ctx : Context ,
584
641
python_version : SupportedPythonVersions | None ,
585
642
* ,
586
- code : CodeblockConverter
643
+ code : CodeblockConverter ,
587
644
) -> None :
588
645
"""
589
646
Run Python code and get the results.
@@ -608,21 +665,25 @@ async def eval_command(
608
665
job = EvalJob .from_code ("\n " .join (code )).as_version (python_version )
609
666
await self .run_job (ctx , job )
610
667
611
- @command (name = "timeit" , aliases = ("ti" ,), usage = "[python_version] [setup_code] <code, ...>" )
668
+ @command (
669
+ name = "timeit" ,
670
+ aliases = ("ti" ,),
671
+ usage = "[python_version] [setup_code] <code, ...>" ,
672
+ )
612
673
@guild_only ()
613
674
@redirect_output (
614
675
destination_channel = Channels .bot_commands ,
615
676
bypass_roles = SNEKBOX_ROLES ,
616
677
categories = NO_SNEKBOX_CATEGORIES ,
617
678
channels = NO_SNEKBOX_CHANNELS ,
618
- ping_user = False
679
+ ping_user = False ,
619
680
)
620
681
async def timeit_command (
621
682
self ,
622
683
ctx : Context ,
623
684
python_version : SupportedPythonVersions | None ,
624
685
* ,
625
- code : CodeblockConverter
686
+ code : CodeblockConverter ,
626
687
) -> None :
627
688
"""
628
689
Profile Python Code to find execution time.
@@ -654,4 +715,8 @@ def predicate_message_edit(ctx: Context, old_msg: Message, new_msg: Message) ->
654
715
655
716
def predicate_emoji_reaction (ctx : Context , reaction : Reaction , user : User ) -> bool :
656
717
"""Return True if the reaction REDO_EMOJI was added by the context message author on this message."""
657
- return reaction .message .id == ctx .message .id and user .id == ctx .author .id and str (reaction ) == REDO_EMOJI
718
+ return (
719
+ reaction .message .id == ctx .message .id
720
+ and user .id == ctx .author .id
721
+ and str (reaction ) == REDO_EMOJI
722
+ )
0 commit comments