forked from PyCQA/pylint-pytest
-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathbase_tester.py
94 lines (72 loc) · 2.8 KB
/
base_tester.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from __future__ import annotations
import os
import sys
from abc import ABC
from pathlib import Path
from pprint import pprint
import astroid
from pylint.checkers import BaseChecker
from pylint.testutils import MessageTest, UnittestLinter
from pylint.utils import ASTWalker
import pylint_pytest.checkers.fixture
# XXX: allow all file names.
# This runs at import time and replaces the module-level ``FILE_NAME_PATTERNS``
# of ``pylint_pytest.checkers.fixture`` with a single wildcard pattern.
# NOTE(review): presumably the fixture checker only processes files whose names
# match these patterns, and the test input files would not match the default
# patterns -- confirm against ``pylint_pytest.checkers.fixture``.
pylint_pytest.checkers.fixture.FILE_NAME_PATTERNS = ("*",)
def get_test_root_path() -> Path:
    """Return the directory that contains this module.

    Assumes ``base_tester.py`` is at ``<root>/tests``, so the returned
    directory is the tests root.
    """
    this_file = Path(__file__)
    return this_file.parent
class BasePytestTester(ABC):
    """Base class for tests that run checkers over an input file.

    Subclasses must define a non-empty string ``MSG_ID`` (enforced by
    ``__init_subclass__``). Input files are looked up at
    ``<test root>/input/<MSG_ID>/<name>.py``, where ``<name>`` is derived from
    the name of the function that calls :meth:`run_linter`.
    """

    # Checker class instantiated in ``setup_method``; it is only added to the
    # walker when ``enable_plugin`` is true.
    CHECKER_CLASS: type[BaseChecker] = BaseChecker
    # Extra checker *classes* (not instances) instantiated in ``setup_method``
    # and always added to the walker, regardless of ``enable_plugin``.
    IMPACTED_CHECKER_CLASSES: list[type[BaseChecker]] = []
    # Message ID under test; also the name of the input sub-directory.
    MSG_ID: str
    # Messages released by the linter during the last ``run_linter`` call.
    # The class-level list is never mutated; ``run_linter`` rebinds the
    # attribute on the instance.
    msgs: list[MessageTest] = []
    # Whether ``CHECKER_CLASS`` takes part in the walk; set by ``run_linter``.
    enable_plugin: bool = True

    def __init_subclass__(cls, **kwargs):
        """Reject subclasses that do not define a non-empty string ``MSG_ID``.

        Raises:
            TypeError: if ``MSG_ID`` is missing, not a ``str``, or empty.
        """
        super().__init_subclass__(**kwargs)
        msg_id = getattr(cls, "MSG_ID", None)
        if not isinstance(msg_id, str) or not msg_id:
            raise TypeError("Subclasses must define a non-empty MSG_ID of type str")

    def run_linter(self, enable_plugin):
        """Parse the input file for the calling test and run the checkers on it.

        The input file name is taken from the *direct caller's* function name
        with the first ``"test_"`` occurrence removed, so this method must be
        called directly from the test function. The released messages are
        stored in ``self.msgs``.

        Args:
            enable_plugin: whether ``self.checker`` is added to the walker.
        """
        self.enable_plugin = enable_plugin

        # Frame 1 is the caller of ``run_linter``; keep this lookup inline so
        # the frame depth stays correct.
        # pylint: disable-next=protected-access
        target_test_file = sys._getframe(1).f_code.co_name.replace("test_", "", 1)
        file_path = os.path.join(
            get_test_root_path(),
            "input",
            self.MSG_ID,
            target_test_file + ".py",
        )

        # Decode explicitly as UTF-8 (the default Python source encoding)
        # instead of relying on the platform's locale encoding.
        with open(file_path, encoding="utf-8") as fin:
            content = fin.read()

        module = astroid.parse(content, module_name=target_test_file)
        module.file = file_path

        self.walk(module)  # run all checkers
        self.msgs = self.linter.release_messages()

    def verify_messages(self, msg_count, msg_id=None):
        """Assert that ``self.msgs`` holds exactly ``msg_count`` matching messages.

        Args:
            msg_count: expected number of messages whose ``msg_id`` matches.
            msg_id: message ID to count; falls back to ``self.MSG_ID`` when
                omitted (or otherwise falsy).

        Raises:
            AssertionError: if the number of matching messages differs.
        """
        msg_id = msg_id or self.MSG_ID

        # only care about ID and count, not the content
        matched_count = sum(1 for message in self.msgs if message.msg_id == msg_id)

        # Print the collected messages before asserting.
        pprint(self.msgs)
        assert matched_count == msg_count, f"expecting {msg_count}, actual {matched_count}"

    def setup_method(self):
        """Create a fresh linter and open the checker instances for one test."""
        self.linter = UnittestLinter()
        self.checker = self.CHECKER_CLASS(self.linter)
        self.impacted_checkers = []

        self.checker.open()

        for checker_class in self.IMPACTED_CHECKER_CLASSES:
            checker = checker_class(self.linter)
            checker.open()
            self.impacted_checkers.append(checker)

    def teardown_method(self):
        """Close every checker that ``setup_method`` opened."""
        self.checker.close()
        for checker in self.impacted_checkers:
            checker.close()

    def walk(self, node):
        """Walk ``node`` with an ``ASTWalker`` holding the configured checkers.

        ``self.checker`` is added only when ``self.enable_plugin`` is true;
        the impacted checkers are always added.
        """
        walker = ASTWalker(self.linter)
        if self.enable_plugin:
            walker.add_checker(self.checker)
        for checker in self.impacted_checkers:
            walker.add_checker(checker)
        walker.walk(node)