Skip to content

Commit b0fd50a

Browse files
committed
fix: add gvar caching improved
1 parent 82458a6 commit b0fd50a

14 files changed

Lines changed: 326 additions & 27 deletions

File tree

.vscode/settings.json

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
{
2+
"python.testing.pytestArgs": [
3+
"tests"
4+
],
5+
"python.testing.unittestEnabled": false,
6+
"python.testing.pytestEnabled": true
7+
}

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "hatchling.build"
44

55
[project]
66
name = "avrae-ls"
7-
version = "0.8.4"
7+
version = "0.8.5"
88
description = "Language server for Avrae draconic aliases"
99
authors = [
1010
{ name = "1drturtle" }

src/avrae_ls/__main__.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,8 @@
1212

1313
from lsprotocol import types
1414

15-
from avrae_ls.testing.alias_tests import (
16-
AliasTestError,
17-
AliasTestResult,
18-
diff_mismatched_parts,
19-
discover_test_files,
20-
parse_alias_tests,
21-
run_alias_tests,
22-
)
15+
from avrae_ls.testing.alias_tests import AliasTestError, AliasTestResult, discover_test_files, parse_alias_tests, run_alias_tests
16+
from avrae_ls.testing._common import diff_mismatched_parts
2317
from avrae_ls.testing.gvar_tests import GVarTestError, GVarTestResult, parse_gvar_tests, run_gvar_tests
2418
from avrae_ls.config import AvraeLSConfig, CONFIG_FILENAME, load_config
2519
from avrae_ls.runtime.context import ContextBuilder
@@ -165,6 +159,7 @@ def _run_alias_tests(
165159

166160
builder = ContextBuilder(config)
167161
executor = MockExecutor(config.service)
162+
shared_gvar_cache = dict(builder.build_baseline().vars.gvars)
168163

169164
test_files = discover_test_files(target, patterns=RUN_TEST_PATTERNS)
170165
alias_cases = []
@@ -186,8 +181,16 @@ def _run_alias_tests(
186181
print(f"No alias or gvar tests found under {target}")
187182
return 1 if parse_errors else 0
188183

189-
alias_results = asyncio.run(run_alias_tests(alias_cases, builder, executor)) if alias_cases else []
190-
gvar_results = asyncio.run(run_gvar_tests(gvar_cases, builder, executor)) if gvar_cases else []
184+
alias_results = (
185+
asyncio.run(run_alias_tests(alias_cases, builder, executor, suite_gvar_cache=shared_gvar_cache))
186+
if alias_cases
187+
else []
188+
)
189+
gvar_results = (
190+
asyncio.run(run_gvar_tests(gvar_cases, builder, executor, suite_gvar_cache=shared_gvar_cache))
191+
if gvar_cases
192+
else []
193+
)
191194
results = [*alias_results, *gvar_results]
192195
_print_test_results(results, workspace_root)
193196

src/avrae_ls/runtime/context.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -269,9 +269,7 @@ def _handle_gvar_response(self, key: str, resp: httpx.Response, *, blocking: boo
269269
return True
270270

271271
def reset(self, gvars: Dict[str, Any] | None = None) -> None:
272-
self._cache = {}
273-
if gvars:
274-
self._cache.update({str(k): v for k, v in gvars.items()})
272+
self.load_snapshot(gvars)
275273

276274
def seed(self, gvars: Dict[str, Any] | None = None) -> None:
277275
"""Merge provided gvars into the cache without dropping fetched values."""
@@ -359,6 +357,11 @@ def ensure_blocking(self, key: str) -> bool:
359357
def snapshot(self) -> Dict[str, Any]:
360358
return dict(self._cache)
361359

360+
def load_snapshot(self, gvars: Dict[str, Any] | None = None) -> None:
361+
self._cache = {}
362+
if gvars:
363+
self._cache.update({str(k): v for k, v in gvars.items()})
364+
362365
async def refresh(self, seed: Dict[str, Any] | None = None, keys: Iterable[str] | None = None) -> Dict[str, Any]:
363366
self.reset(seed)
364367
if keys:

src/avrae_ls/testing/_common.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,17 @@ def diff_mismatched_parts(expected: Any, actual: Any) -> tuple[Any, Any] | None:
111111
return expected, actual
112112

113113

114+
def merge_new_gvars_into_suite_cache(
115+
suite_cache: dict[str, Any], active_cache: dict[str, Any], *, exclude_keys: set[str] | None = None
116+
) -> None:
117+
excluded = {str(key) for key in (exclude_keys or set())}
118+
for key, value in active_cache.items():
119+
key_str = str(key)
120+
if key_str in excluded or key_str in suite_cache:
121+
continue
122+
suite_cache[key_str] = value
123+
124+
114125
def compile_expected_pattern(text: str) -> re.Pattern[str] | None:
115126
"""
116127
Interpret strings with /.../ segments (or re:prefix) as regex.

src/avrae_ls/testing/alias_tests.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
from avrae_ls.testing._common import (
1515
deep_merge_dicts,
1616
dict_matches,
17+
merge_new_gvars_into_suite_cache,
1718
parse_expected_value,
1819
parse_metadata_mapping,
1920
scalar_matches,
@@ -142,7 +143,11 @@ def parse_alias_tests(path: Path) -> list[AliasTestCase]:
142143

143144

144145
async def run_alias_tests(
145-
cases: Iterable[AliasTestCase], builder: ContextBuilder, executor: MockExecutor
146+
cases: Iterable[AliasTestCase],
147+
builder: ContextBuilder,
148+
executor: MockExecutor,
149+
*,
150+
suite_gvar_cache: dict[str, Any] | None = None,
146151
) -> list[AliasTestResult]:
147152
case_list = list(cases)
148153
alias_sources: dict[Path, str] = {}
@@ -163,6 +168,7 @@ async def run_alias_tests(
163168
)
164169

165170
baseline = builder.build_baseline()
171+
shared_gvar_cache = suite_gvar_cache if suite_gvar_cache is not None else dict(baseline.vars.gvars)
166172
results: list[AliasTestResult] = []
167173
for case in case_list:
168174
error = alias_errors.get(case.alias_path)
@@ -176,6 +182,7 @@ async def run_alias_tests(
176182
executor,
177183
alias_source=alias_sources.get(case.alias_path),
178184
base_context=baseline,
185+
suite_gvar_cache=shared_gvar_cache,
179186
)
180187
)
181188
return results
@@ -188,6 +195,7 @@ async def run_alias_test(
188195
*,
189196
alias_source: str | None = None,
190197
base_context: ContextData | None = None,
198+
suite_gvar_cache: dict[str, Any] | None = None,
191199
) -> AliasTestResult:
192200
if alias_source is None:
193201
source_started = time.perf_counter()
@@ -212,9 +220,17 @@ async def run_alias_test(
212220
ctx_data.vars = ctx_data.vars.merge(VarSources.from_data(case.var_overrides))
213221
if case.character_overrides:
214222
ctx_data.character = deep_merge_dicts(ctx_data.character, case.character_overrides)
215-
builder.gvar_resolver.reset(ctx_data.vars.gvars)
223+
shared_gvar_cache = suite_gvar_cache if suite_gvar_cache is not None else dict(ctx_data.vars.gvars)
224+
builder.gvar_resolver.load_snapshot(shared_gvar_cache)
225+
builder.gvar_resolver.seed(ctx_data.vars.gvars)
226+
local_only_gvars = set(ctx_data.vars.gvars.keys())
216227

217228
rendered = await render_alias_command(alias_source, executor, ctx_data, builder.gvar_resolver, args=case.args)
229+
merge_new_gvars_into_suite_cache(
230+
shared_gvar_cache,
231+
builder.gvar_resolver.snapshot(),
232+
exclude_keys=local_only_gvars,
233+
)
218234
if rendered.error:
219235
return AliasTestResult(
220236
case=case,

src/avrae_ls/testing/gvar_tests.py

Lines changed: 25 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,13 @@
1010
from avrae_ls.gvar_utils import sanitize_gvar_binding
1111
from avrae_ls.runtime.context import ContextBuilder, ContextData
1212
from avrae_ls.runtime.runtime import MockExecutor, ModuleExecutionError
13-
from avrae_ls.testing._common import deep_merge_dicts, parse_expected_value, parse_metadata_mapping, value_matches
13+
from avrae_ls.testing._common import (
14+
deep_merge_dicts,
15+
merge_new_gvars_into_suite_cache,
16+
parse_expected_value,
17+
parse_metadata_mapping,
18+
value_matches,
19+
)
1420

1521
log = logging.getLogger(__name__)
1622

@@ -124,7 +130,11 @@ def parse_gvar_tests(path: Path) -> list[GVarTestCase]:
124130

125131

126132
async def run_gvar_tests(
127-
cases: Iterable[GVarTestCase], builder: ContextBuilder, executor: MockExecutor
133+
cases: Iterable[GVarTestCase],
134+
builder: ContextBuilder,
135+
executor: MockExecutor,
136+
*,
137+
suite_gvar_cache: dict[str, Any] | None = None,
128138
) -> list[GVarTestResult]:
129139
case_list = list(cases)
130140
gvar_sources: dict[Path, str] = {}
@@ -141,6 +151,7 @@ async def run_gvar_tests(
141151
log.debug("Loaded %d gvar source file(s) for tests in %.2fms", len(gvar_sources), log_elapsed)
142152

143153
baseline = builder.build_baseline()
154+
shared_gvar_cache = suite_gvar_cache if suite_gvar_cache is not None else dict(baseline.vars.gvars)
144155
results: list[GVarTestResult] = []
145156
for case in case_list:
146157
error = gvar_errors.get(case.gvar_path)
@@ -154,6 +165,7 @@ async def run_gvar_tests(
154165
executor,
155166
gvar_source=gvar_sources.get(case.gvar_path),
156167
base_context=baseline,
168+
suite_gvar_cache=shared_gvar_cache,
157169
)
158170
)
159171
return results
@@ -166,6 +178,7 @@ async def run_gvar_test(
166178
*,
167179
gvar_source: str | None = None,
168180
base_context: ContextData | None = None,
181+
suite_gvar_cache: dict[str, Any] | None = None,
169182
) -> GVarTestResult:
170183
if gvar_source is None:
171184
source_started = time.perf_counter()
@@ -186,15 +199,24 @@ async def run_gvar_test(
186199
ctx_data.vars = ctx_data.vars.merge(VarSources.from_data(case.var_overrides))
187200
if case.character_overrides:
188201
ctx_data.character = deep_merge_dicts(ctx_data.character, case.character_overrides)
189-
builder.gvar_resolver.reset(ctx_data.vars.gvars)
202+
shared_gvar_cache = suite_gvar_cache if suite_gvar_cache is not None else dict(ctx_data.vars.gvars)
203+
builder.gvar_resolver.load_snapshot(shared_gvar_cache)
204+
builder.gvar_resolver.seed(ctx_data.vars.gvars)
205+
local_only_gvars = set(ctx_data.vars.gvars.keys())
190206

191207
collision = _reserved_name_collision(case.binding_name, executor, ctx_data)
192208
if collision is not None:
193209
return GVarTestResult(case=case, passed=False, actual=None, stdout="", error=collision)
194210

195211
builder.gvar_resolver.seed({case.gvar_name: gvar_source})
212+
local_only_gvars.add(case.gvar_name)
196213
wrapped_code = _wrap_test_body(case.binding_name, case.gvar_name, case.body)
197214
result = await executor.run(wrapped_code, ctx_data, builder.gvar_resolver)
215+
merge_new_gvars_into_suite_cache(
216+
shared_gvar_cache,
217+
builder.gvar_resolver.snapshot(),
218+
exclude_keys=local_only_gvars,
219+
)
198220
if result.error:
199221
error_line, error_col = _map_error_position(result.error, wrapper_lines=1)
200222
return GVarTestResult(

tests/test_alias_tests.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,3 +355,50 @@ async def test_alias_tests_allow_metadata_vars_override(tmp_path):
355355
case = parse_alias_tests(test_path)[0]
356356
result = (await run_alias_tests([case], builder, executor))[0]
357357
assert result.passed
358+
359+
360+
@pytest.mark.asyncio
361+
async def test_alias_tests_reuse_nested_gvar_fetches_across_cases(monkeypatch, tmp_path):
362+
alias_path = tmp_path / "outer.alias"
363+
alias_path.write_text('!alias outer echo <drac2>using(mod="remote")\nreturn mod.answer</drac2>')
364+
test_path = tmp_path / "outer.alias-test"
365+
test_path.write_text('!outer\n---\n"42"\n\n!outer\n---\n"42"\n')
366+
367+
config = AvraeLSConfig.default(tmp_path)
368+
config.enable_gvar_fetch = True
369+
config.service.token = "token"
370+
builder = ContextBuilder(config)
371+
executor = MockExecutor(config.service)
372+
373+
calls: list[str] = []
374+
375+
class DummyResponse:
376+
status_code = 200
377+
378+
def __init__(self, key: str):
379+
self.key = key
380+
381+
def json(self):
382+
return {"value": "answer = 42\n"}
383+
384+
class DummyClient:
385+
def __init__(self, **kwargs):
386+
pass
387+
388+
async def __aenter__(self):
389+
return self
390+
391+
async def __aexit__(self, exc_type, exc, tb):
392+
return False
393+
394+
async def get(self, url, headers=None):
395+
calls.append(url.rsplit("/", 1)[-1])
396+
return DummyResponse(calls[-1])
397+
398+
monkeypatch.setattr("avrae_ls.runtime.context.httpx.AsyncClient", DummyClient)
399+
400+
cases = parse_alias_tests(test_path)
401+
results = await run_alias_tests(cases, builder, executor)
402+
403+
assert all(result.passed for result in results)
404+
assert calls == ["remote"]

0 commit comments

Comments
 (0)