Skip to content

Commit 1a56c39

Browse files
Scope Variables In Unittest Template
Isolates the global variables in the unittest template into a function, to prevent them from being imported. Signed-off-by: Hassan Abouelela <[email protected]>
1 parent 1f7e602 commit 1a56c39

File tree

1 file changed

+37
-30
lines changed

1 file changed

+37
-30
lines changed

Diff for: resources/unittest_template.py

+37-30
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,16 @@
22
"""This template is used inside snekbox to evaluate and test user code."""
33
import ast
44
import base64
5+
import functools
56
import io
6-
import os
77
import sys
88
import traceback
9+
import typing
910
import unittest
1011
from itertools import chain
1112
from types import ModuleType, SimpleNamespace
1213
from typing import NoReturn
13-
from unittest import mock
14+
1415

1516
### USER CODE
1617

@@ -20,9 +21,10 @@ class RunnerTestCase(unittest.IsolatedAsyncioTestCase):
2021

2122

2223
normal_exit = False
24+
_EXIT_WRAPPER_TYPE = typing.Callable[[int], None]
2325

2426

25-
def _exit_sandbox(code: int) -> NoReturn:
27+
def _exit_sandbox(code: int, stdout: io.StringIO, result_writer: io.StringIO) -> NoReturn:
2628
"""
2729
Exit the sandbox by printing the result to the actual stdout and exit with the provided code.
2830
@@ -34,64 +36,69 @@ def _exit_sandbox(code: int) -> NoReturn:
3436
3537
137 can also be generated by NsJail when killing the process.
3638
"""
37-
print(RESULT.getvalue(), file=ORIGINAL_STDOUT, end="")
39+
print(result_writer.getvalue(), file=stdout, end="")
3840
global normal_exit
3941
normal_exit = True
4042
sys.exit(code)
4143

4244

43-
def _load_user_module() -> ModuleType:
45+
def _load_user_module(result_writer, exit_wrapper: _EXIT_WRAPPER_TYPE) -> ModuleType:
4446
"""Load the user code into a new module and return it."""
4547
code = base64.b64decode(USER_CODE).decode("utf8")
4648
try:
4749
ast.parse(code, "<input>")
4850
except SyntaxError:
49-
RESULT.write("".join(traceback.format_exception(*sys.exc_info(), limit=0)))
50-
_exit_sandbox(5)
51+
result_writer.write("".join(traceback.format_exception(*sys.exc_info(), limit=0)))
52+
exit_wrapper(5)
5153

5254
_module = ModuleType("module")
5355
exec(code, _module.__dict__)
5456

5557
return _module
5658

5759

58-
def _main() -> None:
60+
def _main(result_writer: io.StringIO, module: ModuleType, exit_wrapper: _EXIT_WRAPPER_TYPE) -> None:
5961
suite = unittest.defaultTestLoader.loadTestsFromTestCase(RunnerTestCase)
62+
globals()["module"] = module
6063
result = suite.run(unittest.TestResult())
6164

62-
RESULT.write(str(int(result.wasSuccessful())))
65+
result_writer.write(str(int(result.wasSuccessful())))
6366

6467
if not result.wasSuccessful():
65-
RESULT.write(
68+
result_writer.write(
6669
";".join(chain(
6770
(error[0]._testMethodName.removeprefix("test_") for error in result.errors),
6871
(failure[0]._testMethodName.removeprefix("test_") for failure in result.failures)
6972
))
7073
)
7174

72-
_exit_sandbox(0)
75+
exit_wrapper(0)
7376

7477

75-
try:
76-
# Fake file object not writing anything
77-
DEVNULL = SimpleNamespace(write=lambda *_: None, flush=lambda *_: None)
78+
def _entry():
79+
result_writer = io.StringIO()
80+
exit_wrapper = functools.partial(_exit_sandbox, stdout=sys.stdout, result_writer=result_writer)
7881

79-
RESULT = io.StringIO()
80-
ORIGINAL_STDOUT = sys.__stdout__
82+
try:
83+
# Fake file object not writing anything
84+
devnull = SimpleNamespace(write=lambda *_: None, flush=lambda *_: None)
8185

82-
# stdout/err is patched in order to control what is outputted by the runner
83-
sys.__stdout__ = sys.stdout = DEVNULL
84-
sys.__stderr__ = sys.stderr = DEVNULL
86+
# stdout/err is patched in order to control what is outputted by the runner
87+
sys.__stdout__ = sys.stdout = devnull
88+
sys.__stderr__ = sys.stderr = devnull
8589

86-
# Load the user code as a global module variable
87-
try:
88-
module = _load_user_module()
90+
# Load the user code as a global module variable
91+
try:
92+
module = _load_user_module(result_writer, exit_wrapper)
93+
except BaseException as e:
94+
result_writer.write(f"Uncaught exception while loading user code: {e}")
95+
exit_wrapper(6)
96+
97+
_main(result_writer, module, exit_wrapper)
8998
except BaseException as e:
90-
RESULT.write(f"Uncaught exception while loading user code: {e}")
91-
_exit_sandbox(6)
92-
_main()
93-
except BaseException as e:
94-
if isinstance(e, SystemExit) and normal_exit:
95-
raise e from None
96-
RESULT.write(f"Uncaught exception inside runner: {e}")
97-
_exit_sandbox(99)
99+
if isinstance(e, SystemExit) and normal_exit:
100+
raise e from None
101+
result_writer.write(f"Uncaught exception inside runner: {e}")
102+
exit_wrapper(99)
103+
104+
_entry()

0 commit comments

Comments
 (0)