diff --git a/cyaron/compare.py b/cyaron/compare.py index 3b4fec0..53ec9c5 100644 --- a/cyaron/compare.py +++ b/cyaron/compare.py @@ -1,14 +1,18 @@ from __future__ import absolute_import, print_function -from .io import IO -from . import log -from cyaron.utils import * -from cyaron.consts import * -from cyaron.graders import CYaRonGraders -import subprocess + import multiprocessing +import subprocess import sys +from concurrent.futures import ThreadPoolExecutor from io import open -import os +from typing import List, Optional, Tuple, Union, cast + +from cyaron.consts import * +from cyaron.graders import CYaRonGraders, GraderType3 +from cyaron.utils import * + +from . import log +from .io import IO class CompareMismatch(ValueError): @@ -22,11 +26,16 @@ def __str__(self): return "In program: '{}'. {}".format(self.name, self.mismatch) +PrgoramType = Union[str, Tuple[str, ...], List[str]] + + class Compare: @staticmethod - def __compare_two(name, content, std, grader): - (result, info) = CYaRonGraders.invoke(grader, content, std) + def __compare_two(name: PrgoramType, content: str, std: str, + input_content: str, grader: Union[str, GraderType3]): + result, info = CYaRonGraders.invoke(grader, content, std, + input_content) status = "Correct" if result else "!!!INCORRECT!!!" info = info if info is not None else "" log.debug("{}: {} {}".format(name, status, info)) @@ -34,13 +43,18 @@ def __compare_two(name, content, std, grader): raise CompareMismatch(name, info) @staticmethod - def __process_file(file): + def __process_output_file(file: Union[str, IO]): if isinstance(file, IO): + if file.output_filename is None: + raise ValueError("IO object has no output file.") file.flush_buffer() - file.output_file.seek(0) - return file.output_filename, file.output_file.read() + with open(file.output_filename, + "r", + newline="\n", + encoding='utf-8') as f: + return file.output_filename, f.read() else: - with open(file, "r", newline="\n") as f: + with open(file, "r", newline="\n", encoding="utf-8") as f: return file, f.read() @staticmethod @@ -64,7 +78,7 @@ def output(cls, *files, **kwargs): ("stop_on_incorrect", None), ), ) - std = kwargs["std"] + std: IO = kwargs["std"] grader = kwargs["grader"] max_workers = kwargs["max_workers"] job_pool = kwargs["job_pool"] @@ -75,8 +89,6 @@ def output(cls, *files, **kwargs): if (max_workers is None or max_workers >= 0) and job_pool is None: max_workers = cls.__normal_max_workers(max_workers) try: - from concurrent.futures import ThreadPoolExecutor - with ThreadPoolExecutor(max_workers=max_workers) as job_pool: return cls.output(*files, std=std, @@ -87,16 +99,21 @@ def output(cls, *files, **kwargs): pass def get_std(): - return cls.__process_file(std)[1] + return cls.__process_output_file(std)[1] if job_pool is not None: - std = job_pool.submit(get_std).result() + std_answer = job_pool.submit(get_std).result() else: - std = get_std() + std_answer = get_std() + + with open(std.input_filename, "r", newline="\n", + encoding="utf-8") as input_file: + input_text = input_file.read() def do(file): - (file_name, content) = cls.__process_file(file) - cls.__compare_two(file_name, content, std, grader) + (file_name, content) = cls.__process_output_file(file) + cls.__compare_two(file_name, content, std_answer, input_text, + grader) if job_pool is not None: job_pool.map(do, files) @@ -104,35 +121,36 @@ def do(file): [x for x in map(do, files)] @classmethod - def program(cls, *programs, **kwargs): - kwargs = unpack_kwargs( - "program", - kwargs, - ( - "input", - ("std", None), - ("std_program", None), - ("grader", DEFAULT_GRADER), - ("max_workers", -1), - ("job_pool", None), - ("stop_on_incorrect", None), - ), - ) - input = kwargs["input"] - std = kwargs["std"] - std_program = kwargs["std_program"] - grader = kwargs["grader"] - max_workers = kwargs["max_workers"] - job_pool = kwargs["job_pool"] - if kwargs["stop_on_incorrect"] is not None: + def program(cls, + *programs: Union[PrgoramType, Tuple[PrgoramType, float]], + input: Union[IO, str], + std: Optional[Union[str, IO]] = None, + std_program: Optional[Union[str, Tuple[str, ...], + List[str]]] = None, + grader: Union[str, GraderType3] = DEFAULT_GRADER, + max_workers: Optional[int] = -1, + job_pool: Optional[ThreadPoolExecutor] = None, + stop_on_incorrect=None): + """ + Compare the output of the programs with the standard output. + + Args: + programs: The programs to be compared. + input: The input file. + std: The standard output file. + std_program: The program that generates the standard output. + grader: The grader to be used. + max_workers: The maximum number of workers. + job_pool: The job pool. + stop_on_incorrect: Deprecated and has no effect. + """ + if stop_on_incorrect is not None: log.warn( "parameter stop_on_incorrect is deprecated and has no effect.") if (max_workers is None or max_workers >= 0) and job_pool is None: max_workers = cls.__normal_max_workers(max_workers) try: - from concurrent.futures import ThreadPoolExecutor - with ThreadPoolExecutor(max_workers=max_workers) as job_pool: return cls.program(*programs, input=input, @@ -144,74 +162,70 @@ def program(cls, *programs, **kwargs): except ImportError: pass - if not isinstance(input, IO): - raise TypeError("expect {}, got {}".format( - type(IO).__name__, - type(input).__name__)) - input.flush_buffer() - input.input_file.seek(0) + if isinstance(input, IO): + input.flush_buffer() if std_program is not None: - def get_std(): - with open(os.dup(input.input_file.fileno()), "r", - newline="\n") as input_file: - content = make_unicode( - subprocess.check_output( - std_program, - shell=(not list_like(std_program)), - stdin=input.input_file, - universal_newlines=True, - )) - input_file.seek(0) + def get_std_from_std_program(): + with open(input.input_filename + if isinstance(input, IO) else input, + "r", + newline="\n", + encoding="utf-8") as input_file: + content = subprocess.check_output( + std_program, + shell=(not list_like(std_program)), + stdin=input_file, + universal_newlines=True, + encoding="utf-8") return content if job_pool is not None: - std = job_pool.submit(get_std).result() + std = job_pool.submit(get_std_from_std_program).result() else: - std = get_std() + std = get_std_from_std_program() elif std is not None: - def get_std(): - return cls.__process_file(std)[1] + def get_std_from_std_file(): + return cls.__process_output_file(cast(Union[str, IO], std))[1] if job_pool is not None: - std = job_pool.submit(get_std).result() + std = job_pool.submit(get_std_from_std_file).result() else: - std = get_std() + std = get_std_from_std_file() else: raise TypeError( "program() missing 1 required non-None keyword-only argument: 'std' or 'std_program'" ) - def do(program_name): + with open(input.input_filename if isinstance(input, IO) else input, + "r", + newline="\n", + encoding="utf-8") as input_file: + input_text = input_file.read() + + def do(program_name: Union[PrgoramType, Tuple[PrgoramType, float]]): timeout = None - if (list_like(program_name) and len(program_name) == 2 - and int_like(program_name[-1])): - program_name, timeout = program_name - with open(os.dup(input.input_file.fileno()), "r", - newline="\n") as input_file: - if timeout is None: - content = make_unicode( - subprocess.check_output( - program_name, - shell=(not list_like(program_name)), - stdin=input_file, - universal_newlines=True, - )) - else: - content = make_unicode( - subprocess.check_output( - program_name, - shell=(not list_like(program_name)), - stdin=input_file, - universal_newlines=True, - timeout=timeout, - )) - input_file.seek(0) - cls.__compare_two(program_name, content, std, grader) + if isinstance(program_name, tuple) and len(program_name) == 2 and ( + isinstance(program_name[1], float) + or isinstance(program_name[1], int)): + program_name, timeout = cast(Tuple[PrgoramType, float], + program_name) + else: + program_name = cast(PrgoramType, program_name) + content = subprocess.check_output( + list(program_name) + if isinstance(program_name, tuple) else program_name, + shell=(not list_like(program_name)), + input=input_text, + universal_newlines=True, + encoding="utf-8", + timeout=timeout) + cls.__compare_two(program_name, content, std, input_text, grader) if job_pool is not None: job_pool.map(do, programs) else: - [x for x in map(do, programs)] + for program in programs: + do(program) diff --git a/cyaron/graders/__init__.py b/cyaron/graders/__init__.py index b11a375..486c730 100644 --- a/cyaron/graders/__init__.py +++ b/cyaron/graders/__init__.py @@ -1,4 +1,5 @@ -from .graderregistry import CYaRonGraders +from .graderregistry import CYaRonGraders, GraderType2, GraderType3 from .fulltext import fulltext -from .noipstyle import noipstyle \ No newline at end of file +from .noipstyle import noipstyle +from .testlib_checker import TestlibChecker diff --git a/cyaron/graders/graderregistry.py b/cyaron/graders/graderregistry.py index 659f239..1a65f32 100644 --- a/cyaron/graders/graderregistry.py +++ b/cyaron/graders/graderregistry.py @@ -1,18 +1,54 @@ +from typing import Callable, Tuple, Dict, Union, Any + +__all__ = ['CYaRonGraders', 'GraderType2', 'GraderType3'] + +GraderType2 = Callable[[str, str], Tuple[bool, Any]] +GraderType3 = Callable[[str, str, str], Tuple[bool, Any]] + + class GraderRegistry: - _registry = dict() + """A registry for grader functions.""" + _registry: Dict[str, GraderType3] = {} + + def grader2(self, name: str): + """ + This decorator registers a grader function under a specific name in the registry. + + The function being decorated should accept exactly two parameters (excluding + the content input). + """ + + def wrapper(func: GraderType2): + self._registry[name] = lambda content, std, _: func(content, std) + return func + + return wrapper + + grader = grader2 - def grader(self, name): + def grader3(self, name: str): + """ + This decorator registers a grader function under a specific name in the registry. + + The function being decorated should accept exactly three parameters. + """ - def wrapper(func): + def wrapper(func: GraderType3): self._registry[name] = func return func return wrapper - def invoke(self, name, content, std): - return self._registry[name](content, std) + def invoke(self, grader: Union[str, GraderType3], content: str, std: str, + input_content: str): + """Invoke a grader function by name or function object.""" + if isinstance(grader, str): + return self._registry[grader](content, std, input_content) + else: + return grader(content, std, input_content) - def check(self, name): + def check(self, name: str): + """Check if a grader is registered.""" return name in self._registry diff --git a/cyaron/graders/testlib_checker.py b/cyaron/graders/testlib_checker.py new file mode 100644 index 0000000..513e7d5 --- /dev/null +++ b/cyaron/graders/testlib_checker.py @@ -0,0 +1,66 @@ +import tempfile +import subprocess +import xml.etree.ElementTree as xmlElementTree +from typing import Optional +from os.path import join as path_join + +__all__ = ["TestlibChecker"] + + +class TestlibCheckerResult: + + def __init__(self, result: Optional[str], outcome: str, + pctype: Optional[str]): + self.result = result + self.outcome = outcome + self.pctype = pctype + + def __str__(self): + return ' '.join([self.outcome] + + ([] if self.pctype is None else [f'({self.pctype})']) + + ([] if self.result is None else [self.result])) + + +class TestlibChecker: + """ + A grader that uses the testlib checker. + """ + + def __init__(self, checker_path: str): + self.checker_path = checker_path + + def __call__(self, outs: str, ans: str, ins: str): + with tempfile.NamedTemporaryFile('w') as inf, \ + tempfile.NamedTemporaryFile('w') as outf, \ + tempfile.NamedTemporaryFile('w') as ansf, \ + tempfile.TemporaryDirectory() as checker_output_dir: + inf.write(ins) + outf.write(outs) + ansf.write(ans) + inf.flush() + outf.flush() + ansf.flush() + checker_output_file = path_join(checker_output_dir, + 'checker_output.xml') + + result = subprocess.run((self.checker_path, inf.name, outf.name, + ansf.name, checker_output_file, '-appes'), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + check=False) + if result.stderr.strip() != 'See file to check exit message': + raise ValueError("Invalid output from checker: " + + result.stderr) + + result_element = xmlElementTree.parse( + checker_output_file).getroot() + if result_element.tag != 'result': + raise ValueError("Invalid output from checker") + result_text = result_element.text + result_outcome = result_element.get('outcome') + if result_outcome is None: + raise ValueError("Invalid output from checker") + result_pctype = result_element.get('pctype') + return result_outcome == 'accepted', TestlibCheckerResult( + result_text, result_outcome, result_pctype) diff --git a/cyaron/io.py b/cyaron/io.py index 61b5d4e..5e186b1 100644 --- a/cyaron/io.py +++ b/cyaron/io.py @@ -100,7 +100,7 @@ def __init__( # type: ignore output_file = "{}{{}}{}".format( self.__escape_format(file_prefix), self.__escape_format(output_suffix)) - self.input_filename, self.output_filename = None, None + self.input_filename, self.output_filename = cast(str, None), None self.__input_temp, self.__output_temp = False, False self.__init_file(input_file, data_id, "i", make_dirs) if not disable_output: @@ -357,3 +357,5 @@ def output_clear_content(self, pos: int = 0): def flush_buffer(self): """Flush the input file""" self.input_file.flush() + if self.output_file: + self.output_file.flush() diff --git a/cyaron/tests/compare_test.py b/cyaron/tests/compare_test.py index bc0a830..ca1e6c1 100644 --- a/cyaron/tests/compare_test.py +++ b/cyaron/tests/compare_test.py @@ -8,6 +8,7 @@ from cyaron.output_capture import captured_output from cyaron.graders.mismatch import * from cyaron.compare import CompareMismatch +from cyaron.graders import CYaRonGraders log.set_verbose() @@ -108,28 +109,36 @@ def test_fulltext_program(self): correct_out = f'{sys.executable} correct.py: Correct \n{sys.executable} incorrect.py: !!!INCORRECT!!! Hash mismatch: read 53c234e5e8472b6ac51c1ae1cab3fe06fad053beb8ebfd8977b010655bfdd3c3, expected 4355a46b19d348dc2f57c046f8ef63d4538ebb936000f3c9ee954a27460dd865' self.assertEqual(result, correct_out) - def test_file_input(self): + def test_file_input_success(self): with open("correct.py", "w") as f: f.write("print(input())") - with open("std.py", "w") as f: f.write("print(input())") - - io = None - with captured_output() as (out, err): - io = IO() - + io = IO() io.input_writeln("233") - - with captured_output() as (out, err): - Compare.program(f"{sys.executable} correct.py", - std_program=f"{sys.executable} std.py", + with captured_output(): + Compare.program((sys.executable, "correct.py"), + std_program=(sys.executable, "std.py"), input=io, grader="NOIPStyle") - result = out.getvalue().strip() - correct_out = f'{sys.executable} correct.py: Correct' - self.assertEqual(result, correct_out) + def test_file_input_fail(self): + with open("incorrect.py", "w") as f: + f.write("print(input()+'154')") + with open("std.py", "w") as f: + f.write("print(input())") + io = IO() + io.input_writeln("233") + try: + with captured_output(): + Compare.program((sys.executable, "incorrect.py"), + std_program=(sys.executable, "std.py"), + input=io, + grader="NOIPStyle") + except CompareMismatch: + pass + else: + self.fail("Should raise CompareMismatch") def test_concurrent(self): programs = ['test{}.py'.format(i) for i in range(16)] @@ -168,3 +177,80 @@ def test_timeout(self): pass else: self.assertTrue(False) + + def test_custom_grader2_by_name(self): + self.assertEqual(CYaRonGraders.grader, CYaRonGraders.grader2) + + @CYaRonGraders.grader("CustomTestGrader2") + def custom_test_grader2(content: str, std: str): + if content == '1\n' and std == '2\n': + return True, None + return False, "CustomTestGrader failed" + + io = IO() + io.output_writeln("2") + + Compare.program("echo 1", + std=io, + input=IO(), + grader="CustomTestGrader2") + + try: + Compare.program("echo 2", + std=io, + input=IO(), + grader="CustomTestGrader2") + except CompareMismatch as e: + self.assertEqual(e.name, 'echo 2') + self.assertEqual(e.mismatch, "CustomTestGrader failed") + else: + self.fail("Should raise CompareMismatch") + + def test_custom_grader3_by_name(self): + + @CYaRonGraders.grader3("CustomTestGrader3") + def custom_test_grader3(content: str, std: str, input_content: str): + if input_content == '0\n' and content == '1\n' and std == '2\n': + return True, None + return False, "CustomTestGrader failed" + + io = IO() + io.input_writeln("0") + io.output_writeln("2") + + Compare.program("echo 1", std=io, input=io, grader="CustomTestGrader3") + + try: + Compare.program("echo 2", + std=io, + input=io, + grader='CustomTestGrader3') + except CompareMismatch as e: + self.assertEqual(e.name, 'echo 2') + self.assertEqual(e.mismatch, "CustomTestGrader failed") + else: + self.fail("Should raise CompareMismatch") + + def test_custom_grader_by_function(self): + + def custom_test_grader(content: str, std: str, input_content: str): + if input_content == '0\n' and content == '1\n' and std == '2\n': + return True, None + return False, "CustomTestGrader failed" + + io = IO() + io.input_writeln("0") + io.output_writeln("2") + + Compare.program("echo 1", std=io, input=io, grader=custom_test_grader) + + try: + Compare.program("echo 2", + std=io, + input=io, + grader=custom_test_grader) + except CompareMismatch as e: + self.assertEqual(e.name, 'echo 2') + self.assertEqual(e.mismatch, "CustomTestGrader failed") + else: + self.fail("Should raise CompareMismatch")