Commit 2eed38c

Merge branch 'awarno/pydantic-config-validation' of https://github.com/NVIDIA-NeMo/Evaluator into awarno/pydantic-config-validation

2 parents 0e168c9 + d85da06

13 files changed: +460 −216 lines

packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/package_info.py
Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@
 # Below is the _next_ version that will be published, not the currently published one.
 MAJOR = 0
 MINOR = 1
-PATCH = 31
+PATCH = 32
 PRE_RELEASE = ""
 
 # Use the following formatting: (major, minor, patch, pre-release)

packages/nemo-evaluator/pyproject.toml
Lines changed: 1 addition & 1 deletion

@@ -65,7 +65,7 @@ repository = "https://github.com/NVIDIA-NeMo/Evaluator/packages/nemo-evaluator"
 # END(if-changed)
 
 [dependency-groups]
-test = ["pytest", "pytest-cov", "pytest-subtests", "pytest-httpserver", "nvidia-simple-evals"]
+test = ["pytest", "pytest-asyncio", "pytest-cov", "pytest-subtests", "pytest-httpserver", "nvidia-simple-evals"]
 
 docs = [
     "sphinx",

packages/nemo-evaluator/src/nemo_evaluator/adapters/adapter_config.py
Lines changed: 23 additions & 27 deletions

@@ -207,36 +207,19 @@ class AdapterConfig(BaseModel):
         description="Type of the endpoint to run the adapter for",
         default="chat",
     )
-    caching_dir: str | None = Field(
-        description="Directory for caching responses (legacy field)",
-        default=None,
-    )
-    generate_html_report: bool = Field(
-        description="Whether to generate HTML report (legacy field)",
-        default=True,
-    )
     log_failed_requests: bool = Field(
-        description="Whether to log failed requests (legacy field)",
+        description="Whether to log failed requests",
         default=False,
     )
-    tracking_requests_stats: bool = Field(
-        description="Whether to enable request statistics tracking. When enabled, response statistics including token usage, status codes, finish reasons, tool calls, and latency metrics will be collected and added to eval_factory_metrics.json for comprehensive evaluation analysis.",
-        default=True,
-    )
-    html_report_size: int | None = Field(
-        description="Number of request-response pairs to track in HTML report. If this is larger than max_saved_responses or max_saved_requests, it will override those values.",
-        default=5,
-    )
 
     @classmethod
     def get_legacy_defaults(cls) -> dict[str, Any]:
         """Get default values for legacy configuration parameters."""
         return {
-            "generate_html_report": cls.model_fields["generate_html_report"].default,
-            "html_report_size": cls.model_fields["html_report_size"].default,
-            "tracking_requests_stats": cls.model_fields[
-                "tracking_requests_stats"
-            ].default,
+            "generate_html_report": True,
+            "html_report_size": 5,
+            "tracking_requests_stats": True,
+            "caching_dir": None,
             "log_failed_requests": cls.model_fields["log_failed_requests"].default,
             "endpoint_type": cls.model_fields["endpoint_type"].default,
             # Boolean defaults for optional features

@@ -254,7 +237,6 @@ def get_legacy_defaults(cls) -> dict[str, Any]:
             "use_raise_client_errors": False,
             "include_json": True,
             "custom_system_prompt": None,
-            "caching_dir": None,
             "output_dir": None,
             "params_to_add": None,
             "params_to_remove": None,

@@ -303,6 +285,24 @@ def merge_discovery(
             run_config.get("target", {}).get("api_endpoint", {}).get("adapter_config")
         )
 
+        # Validate that legacy parameters are not mixed with interceptors
+        legacy_defaults = cls.get_legacy_defaults()
+        model_fields = set(cls.model_fields.keys())
+        legacy_only_params = set(legacy_defaults.keys()) - model_fields
+
+        for config_name, config in [
+            ("global_adapter_config", global_cfg),
+            ("target.api_endpoint.adapter_config", local_cfg),
+        ]:
+            if config and config.get("interceptors"):
+                found_legacy = [p for p in legacy_only_params if p in config]
+                if found_legacy:
+                    raise ValueError(
+                        f"Cannot use legacy configuration parameters when interceptors are explicitly defined in {config_name}. "
+                        f"Found: {', '.join(sorted(found_legacy))}. "
+                        f"Please remove these and configure using interceptors instead."
+                    )
+
         if not global_cfg and not local_cfg:
             # Create default adapter config with caching enabled by default
             return cls.from_legacy_config({}, run_config)

@@ -746,11 +746,7 @@ def from_legacy_config(
             interceptors=interceptors,
             post_eval_hooks=post_eval_hooks,
             endpoint_type=legacy_config["endpoint_type"],
-            caching_dir=legacy_config["caching_dir"],
-            generate_html_report=legacy_config["generate_html_report"],
             log_failed_requests=legacy_config["log_failed_requests"],
-            tracking_requests_stats=legacy_config["tracking_requests_stats"],
-            html_report_size=legacy_config["html_report_size"],
         )
 
     def get_interceptor_configs(self) -> dict[str, dict[str, Any]]:
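
To make the new merge_discovery check concrete, here is a hedged sketch of the two configuration shapes involved. The key path target.api_endpoint.adapter_config and the legacy-only key caching_dir follow from the diff above; the interceptor entry format is a hypothetical placeholder, not taken from this commit.

# Hypothetical run_config fragments illustrating the new validation in merge_discovery.

# Rejected: a legacy-only parameter ("caching_dir" is no longer a model field) appears
# alongside an explicit "interceptors" list, so merge_discovery raises ValueError.
rejected_run_config = {
    "target": {
        "api_endpoint": {
            "adapter_config": {
                "caching_dir": "/tmp/cache",            # legacy-only parameter
                "interceptors": [{"name": "caching"}],  # hypothetical interceptor entry
            }
        }
    }
}

# Accepted: interceptors only, with no legacy-only keys mixed in.
accepted_run_config = {
    "target": {
        "api_endpoint": {
            "adapter_config": {
                "interceptors": [{"name": "caching"}],  # hypothetical interceptor entry
            }
        }
    }
}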

packages/nemo-evaluator/src/nemo_evaluator/adapters/interceptors/progress_tracking_interceptor.py
Lines changed: 59 additions & 2 deletions

@@ -18,7 +18,8 @@
 import os
 import pathlib
 import threading
-from typing import Optional, final
+import time
+from typing import Annotated, Optional, final
 
 import requests
 from pydantic import Field

@@ -48,10 +49,16 @@ class Params(BaseLoggingParams):
            default="http://localhost:8000",
            description="URL to post the number of processed samples to. Supports expansion of shell variables if present.",
        )
-        progress_tracking_interval: int = Field(
+        progress_tracking_interval: Annotated[int, Field(gt=0)] = Field(
            default=1,
            description="How often (every how many samples) to send a progress information.",
        )
+        progress_tracking_interval_seconds: Optional[
+            Annotated[float | None, Field(gt=0)]
+        ] = Field(
+            default=None,
+            description="How often (every N seconds) to send a progress information in addition to progress_tracking_interval.",
+        )
        request_method: str = Field(
            default="PATCH",
            description="Request method to use for updating the evaluation progress.",

@@ -83,15 +90,30 @@ def __init__(self, params: Params):
        else:
            self.progress_filepath = None
        self._samples_processed = self._initialize_samples_processed()
+        self._last_updated_samples_processed = self._samples_processed
        self._lock = threading.Lock()
 
        # Get logger for this interceptor with interceptor context
        self.logger = get_logger(self.__class__.__name__)
 
+        # Optional update on timer
+        self.progress_tracking_interval_seconds = (
+            params.progress_tracking_interval_seconds
+        )
+        if self.progress_tracking_interval_seconds:
+            self._timer_stopped = False
+            self._update_on_timer_thread = threading.Thread(
+                target=self._update_on_timer,
+                kwargs={"interval_seconds": self.progress_tracking_interval_seconds},
+                daemon=True,
+            )
+            self._update_on_timer_thread.start()
+
        self.logger.info(
            "Progress tracking interceptor initialized",
            progress_tracking_url=self.progress_tracking_url,
            progress_tracking_interval=self.progress_tracking_interval,
+            progress_tracking_interval_seconds=self.progress_tracking_interval_seconds,
            output_dir=str(self.progress_filepath) if self.progress_filepath else None,
            initial_samples_processed=self._samples_processed,
        )

@@ -151,6 +173,34 @@ def _send_progress(self, num_samples: int) -> requests.Response:
            samples_processed=num_samples,
        )
 
+    def _update_on_timer(self, interval_seconds: float):
+        """
+        Sends an update on a timed interval if there has been a change since the last update.
+        This is a blocking function that is expected to be executed in a thread.
+        """
+        assert interval_seconds > 0
+        while True:
+            time.sleep(interval_seconds)
+            with self._lock:
+                if self._timer_stopped:
+                    return
+                if self._last_updated_samples_processed == self._samples_processed:
+                    continue
+                curr_samples = self._samples_processed
+
+            if self.progress_tracking_url is not None:
+                self._send_progress(curr_samples)
+            if self.progress_filepath is not None:
+                self._write_progress(curr_samples)
+
+            self.logger.info(
+                "Progress milestone updated on time interval",
+                samples_processed=curr_samples,
+                interval=self.progress_tracking_interval,
+            )
+            with self._lock:
+                self._last_updated_samples_processed = curr_samples
+
    @final
    def intercept_response(
        self, ar: AdapterResponse, context: AdapterGlobalContext

@@ -177,13 +227,20 @@ def intercept_response(
                samples_processed=curr_samples,
                interval=self.progress_tracking_interval,
            )
+            with self._lock:
+                self._last_updated_samples_processed = curr_samples
 
        return ar
 
    def post_eval_hook(self, context: AdapterGlobalContext) -> None:
        self.logger.info(
            "Post-eval hook executed", total_samples_processed=self._samples_processed
        )
+        with self._lock:
+            if self.progress_tracking_interval_seconds:
+                self._timer_stopped = True
+            if self._samples_processed == self._last_updated_samples_processed:
+                return
 
        if self.progress_tracking_url is not None:
            self._send_progress(self._samples_processed)
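
A minimal usage sketch of the new timer-based option, assuming only the Params fields and constructor behaviour visible in this diff; the concrete values are illustrative.

from nemo_evaluator.adapters.interceptors.progress_tracking_interceptor import (
    ProgressTrackingInterceptor,
)

params = ProgressTrackingInterceptor.Params(
    progress_tracking_url="http://localhost:8000",  # progress endpoint (default request method is PATCH)
    progress_tracking_interval=10,                  # send an update every 10 processed samples
    progress_tracking_interval_seconds=30.0,        # additionally send one every 30 s when samples changed
)
interceptor = ProgressTrackingInterceptor(params)
# __init__ starts a daemon thread running _update_on_timer(); it wakes every 30 s,
# skips the update when _samples_processed has not changed since the last one, and
# exits once post_eval_hook() sets _timer_stopped.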

packages/nemo-evaluator/src/nemo_evaluator/package_info.py
Lines changed: 1 addition & 1 deletion

@@ -16,7 +16,7 @@
 # Below is the _next_ version that will be published, not the currently published one.
 MAJOR = 0
 MINOR = 1
-PATCH = 29
+PATCH = 30
 PRE_RELEASE = ""
 
 # Use the following formatting: (major, minor, patch, pre-release)

packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_progress_tracking_interceptor.py
Lines changed: 79 additions & 46 deletions

@@ -13,14 +13,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 import os
 import threading
-import time
-from typing import List
 from unittest.mock import patch
 
+import pytest
 import requests
-from flask import Flask, request
+from pydantic_core import ValidationError
 
 from nemo_evaluator.adapters.interceptors.progress_tracking_interceptor import (
     ProgressTrackingInterceptor,

@@ -30,49 +30,7 @@
     AdapterRequestContext,
     AdapterResponse,
 )
-
-
-class FakeProgressTrackingServer:
-    """Test server to receive progress tracking webhooks."""
-
-    def __init__(self, port: int = 8000, request_method="PATCH"):
-        self.port = port
-        self.app = Flask(__name__)
-        self.received_updates: List[dict] = []
-        self.lock = threading.Lock()
-
-        @self.app.route("/", methods=[request_method])
-        def progress_webhook():
-            """Receive progress updates."""
-            data = request.get_json()
-            with self.lock:
-                self.received_updates.append(data)
-            return {"status": "ok"}
-
-    def start(self):
-        """Start the server in a background thread."""
-        self.thread = threading.Thread(
-            target=self.app.run, kwargs={"host": "0.0.0.0", "port": self.port}
-        )
-        self.thread.daemon = True
-        self.thread.start()
-        # Give the server time to start
-        time.sleep(0.5)
-
-    def stop(self):
-        """Stop the server."""
-        # Flask doesn't have a clean shutdown, so we'll just let it run as daemon
-        pass
-
-    def get_updates(self) -> List[dict]:
-        """Get all received updates."""
-        with self.lock:
-            return self.received_updates.copy()
-
-    def clear_updates(self):
-        """Clear received updates."""
-        with self.lock:
-            self.received_updates.clear()
+from tests.unit_tests.adapters.testing_utils import FakeProgressTrackingServer
 
 
 class TestProgressTrackingInterceptor:

@@ -255,6 +213,19 @@ def test_network_error_handling(self, mock_request):
         # Verify that the request was attempted
         mock_request.assert_called_once()
 
+    def test_interval_configuration_validation(self):
+        with pytest.raises(ValidationError):
+            ProgressTrackingInterceptor.Params(
+                progress_tracking_url="http://test",
+                progress_tracking_interval=0,
+            )
+
+        with pytest.raises(ValidationError):
+            ProgressTrackingInterceptor.Params(
+                progress_tracking_url="http://test",
+                progress_tracking_interval=-2,
+            )
+
     def test_interval_configuration(self):
         """Test different interval configurations."""
         # Start test server

@@ -367,6 +338,68 @@ def test_configured_method(self):
         finally:
             server.stop()
 
+    def test_interval_timer_validation(self):
+        with pytest.raises(ValidationError):
+            ProgressTrackingInterceptor.Params(
+                progress_tracking_interval_seconds=-1,
+            )
+
+    @pytest.mark.asyncio
+    async def test_interval_timer(self):
+        # Start test server
+        server = FakeProgressTrackingServer(port=8007)
+        server.start()
+
+        try:
+            params = ProgressTrackingInterceptor.Params(
+                progress_tracking_url="http://localhost:8007",
+                progress_tracking_interval=50,
+                progress_tracking_interval_seconds=0.2,
+            )
+            interceptor = ProgressTrackingInterceptor(params)
+            assert interceptor.progress_tracking_url == "http://localhost:8007"
+            assert interceptor.progress_tracking_interval == 50
+            assert interceptor.progress_tracking_interval_seconds == 0.2
+
+            # Create mock response and context
+            mock_response = AdapterResponse(
+                r=requests.Response(),
+                rctx=AdapterRequestContext(),
+            )
+            context = AdapterGlobalContext(output_dir="/tmp", url="http://test")
+
+            # Verify no update until timer interval
+            interceptor.intercept_response(mock_response, context)
+            interceptor.intercept_response(mock_response, context)
+            updates = server.get_updates()
+            assert len(updates) == 0, "no updates until timer interval"
+
+            # Verify first timer interval calls update
+            await asyncio.sleep(0.5)
+            updates = server.get_updates()
+            assert len(updates) == 1, "only expected one update"
+            assert updates[0]["samples_processed"] == 2
+
+            # Verify subsequent timer interval calls update
+            interceptor.intercept_response(mock_response, context)
+            await asyncio.sleep(0.5)
+            updates = server.get_updates()
+            assert len(updates) == 2, "expected second update"
+            assert updates[1]["samples_processed"] == 3
+
+            # No calls to update after timer is stopped
+            interceptor.post_eval_hook(context)
+            interceptor.intercept_response(mock_response, context)
+            assert interceptor._samples_processed == 4
+            await asyncio.sleep(0.5)
+            updates = server.get_updates()
+            assert len(updates) == 2, (
+                "expected post_eval_hook to skip posting update on no change and no updates after post_eval_hook cancels timed updates"
+            )
+
+        finally:
+            server.stop()
+
 
 if __name__ == "__main__":
     # Simple test runner for manual testing

packages/nemo-evaluator/tests/unit_tests/adapters/interceptors/test_reasoning.py
Lines changed: 1 addition & 1 deletion

@@ -204,7 +204,7 @@ def test_migration(
     url = f"http://{AdapterServer.DEFAULT_ADAPTER_HOST}:{adapter_server_migration.port}"
 
     # Wait for server to be ready
-    wait_for_server("localhost", 3825)
+    wait_for_server("localhost", adapter_server_migration.port)
 
     # We parametrize the response of the openai fake server.
     response_data = {

0 commit comments

Comments
 (0)