Skip to content

Commit b6441f5

Browse files
Adds a Slack integration, and fixes import error in main. (#1186)
* First commit for Slack integration. updated code Added fake error * Fixed import error * Undid changes to script * update error message * Removed fake error * Handle Slack alert failures more gracefully (#1189) * Cleaned up PR. * Added back functions * Now, we disable slack alerts by default. * Added test
1 parent 733bf92 commit b6441f5

File tree

5 files changed

+71
-0
lines changed

5 files changed

+71
-0
lines changed

open_instruct/grpo_fast.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -419,6 +419,8 @@ class Args:
419419
"""multiply the gpus used for each oe-eval task"""
420420
eval_priority: Literal["low", "normal", "high", "urgent"] = "normal"
421421
"""the priority of auto-launched evaluation jobs"""
422+
send_slack_alerts: bool = False
423+
"""Whether to send Slack alerts on training failures"""
422424

423425
# Evaluation behavior
424426
eval_on_step_0: bool = False
@@ -3249,6 +3251,10 @@ def main(args: Args, tc: TokenizerConfig, model_config: ModelConfig):
32493251
model_dims,
32503252
checkpoint_state,
32513253
)
3254+
except Exception as e:
3255+
if args.send_slack_alerts:
3256+
utils.send_slack_alert(e)
3257+
raise
32523258
finally:
32533259
cleanup_training_resources(
32543260
stop_event, executor, [inference_results_Q, param_prompt_Q, evaluation_inference_results_Q], actor_manager

open_instruct/test_utils.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from unittest import mock
2020

2121
import pytest
22+
import responses
2223
import torch
2324
import vllm
2425
from dateutil import parser
@@ -240,6 +241,27 @@ def test_description_without_progress(self, mock_is_beaker_job, mock_beaker_from
240241
self.assertNotIn("% complete", desc)
241242

242243

244+
class TestSlackAlert(unittest.TestCase):
245+
@responses.activate
246+
@mock.patch("open_instruct.utils.get_beaker_experiment_url")
247+
@mock.patch("os.environ.get")
248+
def test_send_slack_alert_with_beaker_url(self, mock_environ_get, mock_get_beaker_url):
249+
webhook_url = "https://hooks.slack.com/services/test"
250+
mock_environ_get.return_value = webhook_url
251+
mock_get_beaker_url.return_value = "https://beaker.org/ex/test-123"
252+
253+
responses.add(responses.POST, webhook_url, json={"ok": True}, status=200)
254+
255+
test_error = ValueError("Test error message")
256+
utils.send_slack_alert(test_error)
257+
258+
self.assertEqual(len(responses.calls), 1)
259+
request_body = json.loads(responses.calls[0].request.body)
260+
self.assertIn("<!here> A RL job has died.", request_body["text"])
261+
self.assertIn("https://beaker.org/ex/test-123", request_body["text"])
262+
self.assertIn("Test error message", request_body["text"])
263+
264+
243265
class TestUtilityFunctions(unittest.TestCase):
244266
"""Test utility functions in utils module."""
245267

open_instruct/utils.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2433,3 +2433,29 @@ def combine_reward_metrics(reward_metrics: list[dict[str, Any]]) -> dict[str, An
24332433
# Fallback: keep the latest value if aggregation strategy is unclear.
24342434
combined[key] = records[-1]
24352435
return combined
2436+
2437+
2438+
def send_slack_alert(error: Exception) -> None:
2439+
"""Sends an alert about a training failure to a Slack webhook (if the env var SLACK_WEBHOOK is set)."""
2440+
slack_webhook_url = os.environ.get("SLACK_WEBHOOK")
2441+
if not slack_webhook_url:
2442+
logger.warning("SLACK_WEBHOOK environment variable not set. Skipping Slack alert.")
2443+
return
2444+
beaker_url = get_beaker_experiment_url()
2445+
beaker_message = f"Check it out: {beaker_url}. " if beaker_url else ""
2446+
message = f"<!here> A RL job has died. {beaker_message}Error message: {str(error)}."
2447+
payload = {"text": message}
2448+
response = requests.post(slack_webhook_url, json=payload)
2449+
if not response.ok:
2450+
logger.warning("Failed to send Slack alert with status %s: %s", response.status_code, response.text)
2451+
2452+
2453+
def get_beaker_experiment_url() -> str | None:
2454+
"""If the env var BEAKER_WORKLOAD_ID is set, gets the current experiment URL."""
2455+
try:
2456+
beaker_client = beaker.Beaker.from_env()
2457+
workload = beaker_client.workload.get(os.environ["BEAKER_WORKLOAD_ID"])
2458+
url = beaker_client.experiment.url(workload.experiment)
2459+
return url
2460+
except Exception:
2461+
return None

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ dev = [
8383
"ruff>=0.11.13",
8484
"parameterized>=0.9.0",
8585
"rich>=13.7.0",
86+
"responses>=0.25.8",
8687
]
8788

8889
[tool.pytest.ini_options]

uv.lock

Lines changed: 16 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)