diff --git a/examples/retool/generate_with_retool.py b/examples/retool/generate_with_retool.py index a090af63d7..a1b9c4eee0 100644 --- a/examples/retool/generate_with_retool.py +++ b/examples/retool/generate_with_retool.py @@ -9,6 +9,7 @@ from slime.rollout.sglang_rollout import GenerateState from slime.utils.http_utils import post +from slime.utils import logging_utils from slime.utils.types import Sample # Import reward models @@ -248,26 +249,28 @@ async def generate(args, sample: Sample, sampling_params) -> Sample: "return_logprob": True, # Request log probabilities for training } - # Log payload to wandb for debugging + # Log payload for debugging try: - import wandb - - if wandb.run is not None: + if getattr(args, "use_wandb", False) or getattr(args, "use_swanlab", False) or getattr( + args, "use_tensorboard", False + ): # Count available tools (from tool_specs) available_tools = len(tool_specs) # Count tools used in the current response tools_used = response.count("") - wandb.log( + logging_utils.log( + args, { "debug/payload_length": len(prompt + response), "debug/available_tools": available_tools, "debug/tools_used": tools_used, "debug/turn": turn, - } + }, + step_key="debug/turn", ) - except ImportError: - pass # wandb not available + except Exception: + pass output = await post(url, payload) diff --git a/pyproject.toml b/pyproject.toml index 6f9ee242a5..b057d61190 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,7 +15,7 @@ sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] default_section = "THIRDPARTY" extend_skip = ["setup.py", "docs/source/conf.py"] known_first_party = ["slime", "slime_plugins"] -known_third_party = ["megatron", "wandb", "ray", "transformers"] +known_third_party = ["megatron", "ray", "swanlab", "transformers", "wandb"] src_paths = ["slime", "slime_plugins"] diff --git a/requirements.txt b/requirements.txt index 427680d168..24038d36a3 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ qwen_vl_utils # for VLM ray[default] ring_flash_attn sglang-router>=0.2.3 +swanlab tensorboard transformers wandb diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index e8a1730782..99f013da9e 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1095,6 +1095,40 @@ def add_wandb_arguments(parser): parser.add_argument("--wandb-run-id", type=str, default=None) return parser + # swanlab + def add_swanlab_arguments(parser): + parser.add_argument("--use-swanlab", action="store_true", default=False) + parser.add_argument( + "--swanlab-mode", + type=str, + default=None, + choices=["cloud", "local", "offline", "disabled"], + help="SwanLab mode: cloud (default), local, offline, or disabled.", + ) + parser.add_argument( + "--swanlab-dir", + type=str, + default=None, + help="Directory to store SwanLab logs. Default is swanlog in the current directory.", + ) + parser.add_argument("--swanlab-key", type=str, default=None) + parser.add_argument("--swanlab-host", type=str, default=None) + parser.add_argument("--swanlab-web-host", type=str, default=None) + parser.add_argument("--swanlab-workspace", type=str, default=None) + parser.add_argument("--swanlab-project", type=str, default=None) + parser.add_argument("--swanlab-group", type=str, default=None) + parser.add_argument("--swanlab-experiment-name", type=str, default=None) + parser.add_argument("--swanlab-open-metrics-interval", type=int, default=10) + parser.add_argument( + "--disable-swanlab-random-suffix", + action="store_false", + dest="swanlab_random_suffix", + default=True, + help="Whether to add a random suffix to the SwanLab group name.", + ) + parser.add_argument("--swanlab-run-id", type=str, default=None) + return parser + # tensorboard def add_tensorboard_arguments(parser): # tb_project_name, tb_experiment_name @@ -1379,6 +1413,7 @@ def add_ci_arguments(parser): parser = add_algo_arguments(parser) parser = add_on_policy_distillation_arguments(parser) parser = add_wandb_arguments(parser) + parser = add_swanlab_arguments(parser) parser = add_tensorboard_arguments(parser) parser = add_router_arguments(parser) parser = add_debug_arguments(parser) diff --git a/slime/utils/external_utils/command_utils.py b/slime/utils/external_utils/command_utils.py index cdf2cbe0b0..0d98ba9b8b 100644 --- a/slime/utils/external_utils/command_utils.py +++ b/slime/utils/external_utils/command_utils.py @@ -189,8 +189,43 @@ def check_has_nvlink(): def get_default_wandb_args(test_file: str, run_name_prefix: str | None = None, run_id: str | None = None): - if not os.environ.get("WANDB_API_KEY"): - print("Skip wandb configuration since WANDB_API_KEY is not found") + return get_default_tracking_args(test_file, run_name_prefix=run_name_prefix, run_id=run_id, backend="wandb") + + +def get_default_swanlab_args(test_file: str, run_name_prefix: str | None = None, run_id: str | None = None): + return get_default_tracking_args(test_file, run_name_prefix=run_name_prefix, run_id=run_id, backend="swanlab") + + +def get_default_tracking_args( + test_file: str, + run_name_prefix: str | None = None, + run_id: str | None = None, + backend: str | None = None, +): + if backend is None: + backend = "swanlab" if os.environ.get("SWANLAB_API_KEY") else "wandb" + + if backend == "wandb": + api_key_env = "WANDB_API_KEY" + use_flag = "--use-wandb" + project_flag = "--wandb-project" + group_flag = "--wandb-group" + key_flag = "--wandb-key" + suffix_flag = "--disable-wandb-random-suffix" + skip_msg = "Skip wandb configuration since WANDB_API_KEY is not found" + elif backend == "swanlab": + api_key_env = "SWANLAB_API_KEY" + use_flag = "--use-swanlab" + project_flag = "--swanlab-project" + group_flag = "--swanlab-group" + key_flag = "--swanlab-key" + suffix_flag = "--disable-swanlab-random-suffix" + skip_msg = "Skip swanlab configuration since SWANLAB_API_KEY is not found" + else: + raise ValueError(f"Unsupported backend: {backend}") + + if not os.environ.get(api_key_env): + print(skip_msg) return "" test_file = Path(test_file) @@ -198,20 +233,20 @@ def get_default_wandb_args(test_file: str, run_name_prefix: str | None = None, r if len(test_name) < 6: test_name = f"{test_file.parent.name}_{test_name}" - wandb_run_name = run_id or create_run_id() + run_name = run_id or create_run_id() if (x := os.environ.get("GITHUB_COMMIT_NAME")) is not None: - wandb_run_name += f"_{x}" + run_name += f"_{x}" if (x := run_name_prefix) is not None: - wandb_run_name = f"{x}_{wandb_run_name}" + run_name = f"{x}_{run_name}" # Use the actual key value from environment to avoid shell expansion issues - wandb_key = os.environ.get("WANDB_API_KEY") + api_key = os.environ.get(api_key_env) return ( - "--use-wandb " - f"--wandb-project slime-{test_name} " - f"--wandb-group {wandb_run_name} " - f"--wandb-key '{wandb_key}' " - "--disable-wandb-random-suffix " + f"{use_flag} " + f"{project_flag} slime-{test_name} " + f"{group_flag} {run_name} " + f"{key_flag} '{api_key}' " + f"{suffix_flag} " ) diff --git a/slime/utils/logging_utils.py b/slime/utils/logging_utils.py index 11348a4074..472675a55f 100644 --- a/slime/utils/logging_utils.py +++ b/slime/utils/logging_utils.py @@ -3,6 +3,7 @@ import wandb from . import wandb_utils +from . import swanlab_utils from .tensorboard_utils import _TensorboardAdapter _LOGGER_CONFIGURED = False @@ -26,23 +27,34 @@ def configure_logger(prefix: str = ""): def init_tracking(args, primary: bool = True, **kwargs): if primary: - wandb_utils.init_wandb_primary(args, **kwargs) + if args.use_wandb: + wandb_utils.init_wandb_primary(args, **kwargs) + if args.use_swanlab: + swanlab_utils.init_swanlab_primary(args) else: - wandb_utils.init_wandb_secondary(args, **kwargs) + if args.use_wandb: + wandb_utils.init_wandb_secondary(args, **kwargs) + if args.use_swanlab: + swanlab_utils.init_swanlab_secondary(args) def update_tracking_open_metrics(args, router_addr): - wandb_utils.reinit_wandb_primary_with_open_metrics(args, router_addr) + if args.use_wandb: + wandb_utils.reinit_wandb_primary_with_open_metrics(args, router_addr) + if args.use_swanlab: + swanlab_utils.reinit_swanlab_primary_with_open_metrics(args, router_addr) def finish_tracking(args): - if not args.use_wandb: - return - try: - if wandb.run is not None: - wandb.finish() - except Exception: - logging.getLogger(__name__).exception("Failed to finish wandb run") + if args.use_wandb: + try: + if wandb.run is not None: + wandb.finish() + except Exception: + logging.getLogger(__name__).exception("Failed to finish wandb run") + + if args.use_swanlab: + swanlab_utils.finish_swanlab(args) # TODO further refactor, e.g. put TensorBoard init to the "init" part @@ -50,6 +62,12 @@ def log(args, metrics, step_key: str): if args.use_wandb: wandb.log(metrics) + if args.use_swanlab: + import swanlab + + metrics_except_step = {k: v for k, v in metrics.items() if k != step_key} + swanlab.log(metrics_except_step, step=metrics[step_key]) + if args.use_tensorboard: metrics_except_step = {k: v for k, v in metrics.items() if k != step_key} _TensorboardAdapter(args).log(data=metrics_except_step, step=metrics[step_key]) diff --git a/slime/utils/swanlab_utils.py b/slime/utils/swanlab_utils.py new file mode 100644 index 0000000000..15f981c9e4 --- /dev/null +++ b/slime/utils/swanlab_utils.py @@ -0,0 +1,234 @@ +import logging +import os +import threading +import time +from copy import deepcopy +from dataclasses import dataclass + +import httpx + +try: + import swanlab +except ImportError: # pragma: no cover - optional dependency + swanlab = None + +logger = logging.getLogger(__name__) + +_OPEN_METRICS_MONITOR = None + + +def _require_swanlab(): + if swanlab is None: + raise ImportError("swanlab is not installed. Please install it with: pip install swanlab") + + +def _is_offline_mode(args) -> bool: + if args.swanlab_mode: + return args.swanlab_mode in {"offline", "local", "disabled"} + return os.environ.get("SWANLAB_MODE") in {"offline", "local", "disabled"} + + +def _maybe_login(args): + if _is_offline_mode(args): + return + if args.swanlab_key is not None or args.swanlab_host is not None or args.swanlab_web_host is not None: + swanlab.login(api_key=args.swanlab_key, host=args.swanlab_host, web_host=args.swanlab_web_host) + + +def _sanitize_metric_name(name: str) -> str: + return "".join(ch if (ch.isalnum() or ch == "_") else "_" for ch in name).strip("_") + + +def _parse_prometheus_metrics(text: str) -> dict[str, float]: + metrics = {} + for line in text.splitlines(): + line = line.strip() + if not line or line.startswith("#"): + continue + + parts = line.split() + if len(parts) < 2: + continue + + metric_expr = parts[0] + value_text = parts[1] + + try: + value = float(value_text) + except ValueError: + continue + + if "{" in metric_expr: + name, labels_text = metric_expr.split("{", 1) + labels_text = labels_text.rstrip("}") + label_parts = [] + for item in labels_text.split(","): + if not item or "=" not in item: + continue + key, raw_value = item.split("=", 1) + label_parts.append(f"{_sanitize_metric_name(key)}_{_sanitize_metric_name(raw_value.strip('\\"'))}") + metric_name = "_".join([name, *label_parts]) if label_parts else name + else: + metric_name = metric_expr + + metrics[_sanitize_metric_name(metric_name)] = value + + return metrics + + +@dataclass +class _SwanlabOpenMetricsMonitor: + args: object + router_addr: str + interval_s: int + + def __post_init__(self): + self._stop_event = threading.Event() + self._thread = None + self._poll_step = 0 + + def start(self): + if self._thread is not None: + return self + + self._thread = threading.Thread(target=self._run, name="swanlab-open-metrics", daemon=True) + self._thread.start() + return self + + def stop(self): + self._stop_event.set() + if self._thread is not None and self._thread.is_alive(): + self._thread.join(timeout=2) + + def _run(self): + while not self._stop_event.is_set(): + self.poll_once() + self._stop_event.wait(self.interval_s) + + def poll_once(self): + try: + response = httpx.get(f"{self.router_addr}/engine_metrics", timeout=10.0) + response.raise_for_status() + metrics = _parse_prometheus_metrics(response.text) + if not metrics: + return + + payload = {f"sglang_engine/{key}": value for key, value in metrics.items()} + swanlab.log(payload, step=self._poll_step) + self._poll_step += 1 + except Exception: + logger.exception("Failed to collect SwanLab open metrics from %s", self.router_addr) + + +def _build_init_kwargs(args, primary: bool): + group = args.swanlab_group + experiment_name = args.swanlab_experiment_name or group + + if args.swanlab_random_suffix and group: + group = f"{group}_{swanlab.util.generate_id()}" + if experiment_name is None: + experiment_name = group + + init_kwargs = { + "project": args.swanlab_project, + "workspace": args.swanlab_workspace, + "group": group, + "experiment_name": experiment_name, + "config": _compute_config_for_logging(args), + "mode": args.swanlab_mode, + "id": getattr(args, "swanlab_run_id", None), + "resume": "allow" if getattr(args, "swanlab_run_id", None) is not None else None, + "reinit": True, + } + + if args.swanlab_mode in (None, "cloud"): + init_kwargs["parallel"] = "shared" + + if args.swanlab_dir: + os.makedirs(args.swanlab_dir, exist_ok=True) + init_kwargs["logdir"] = args.swanlab_dir + logger.info(f"SwanLab logs will be stored in: {args.swanlab_dir}") + + if primary and getattr(args, "rank", 0) != 0: + init_kwargs["reinit"] = True + + return init_kwargs + + +def init_swanlab_primary(args): + if not args.use_swanlab: + args.swanlab_run_id = None + return + + _require_swanlab() + _maybe_login(args) + + init_kwargs = _build_init_kwargs(args, primary=True) + swanlab.init(**init_kwargs) + args.swanlab_run_id = swanlab.get_run().id + + +def reinit_swanlab_primary_with_open_metrics(args, router_addr): + global _OPEN_METRICS_MONITOR + + if not args.use_swanlab: + return + if router_addr is None: + return + + _require_swanlab() + if _is_offline_mode(args): + logger.info("SwanLab open metrics disabled in offline/local/disabled mode.") + return + + if _OPEN_METRICS_MONITOR is not None: + _OPEN_METRICS_MONITOR.stop() + + logger.info(f"Starting SwanLab open metrics monitor at {router_addr}.") + _OPEN_METRICS_MONITOR = _SwanlabOpenMetricsMonitor( + args=args, + router_addr=router_addr, + interval_s=max(int(getattr(args, "swanlab_open_metrics_interval", 10)), 1), + ).start() + + +def init_swanlab_secondary(args): + if not args.use_swanlab: + return + + wandb_run_id = getattr(args, "swanlab_run_id", None) + if wandb_run_id is None: + return + + _require_swanlab() + _maybe_login(args) + + init_kwargs = _build_init_kwargs(args, primary=False) + init_kwargs["id"] = wandb_run_id + init_kwargs["resume"] = "allow" + swanlab.init(**init_kwargs) + + +def finish_swanlab(args): + global _OPEN_METRICS_MONITOR + + if not args.use_swanlab: + return + if swanlab is None: + return + try: + if _OPEN_METRICS_MONITOR is not None: + _OPEN_METRICS_MONITOR.stop() + _OPEN_METRICS_MONITOR = None + run = swanlab.get_run() + if run is not None: + swanlab.finish() + except Exception: + logging.getLogger(__name__).exception("Failed to finish SwanLab run") + + +def _compute_config_for_logging(args): + output = deepcopy(args.__dict__) + whitelist_env_vars = ["SLURM_JOB_ID"] + output["env_vars"] = {k: v for k, v in os.environ.items() if k in whitelist_env_vars} + return output diff --git a/slime_plugins/rollout_buffer/rollout_buffer_example.py b/slime_plugins/rollout_buffer/rollout_buffer_example.py index 74b4b4c46c..aa1122dbd4 100644 --- a/slime_plugins/rollout_buffer/rollout_buffer_example.py +++ b/slime_plugins/rollout_buffer/rollout_buffer_example.py @@ -4,11 +4,11 @@ import aiohttp import requests -import wandb from transformers import AutoTokenizer from slime.utils.async_utils import run from slime.utils.mask_utils import MultiTurnLossMaskGenerator +from slime.utils import logging_utils from slime.utils.types import Sample __all__ = ["generate_rollout"] @@ -107,7 +107,7 @@ def log_raw_info(args, all_meta_info, rollout_id): "avg_reward": weighted_reward_sum / total_samples, } ) - if args.use_wandb: + if args.use_wandb or args.use_swanlab or args.use_tensorboard: log_dict = { "rollout/no_filter/total_samples": final_meta_info["total_samples"], "rollout/no_filter/avg_reward": final_meta_info["avg_reward"], @@ -118,18 +118,11 @@ def log_raw_info(args, all_meta_info, rollout_id): if not args.wandb_always_use_train_step else rollout_id * args.rollout_batch_size * args.n_samples_per_prompt // args.global_batch_size ) - if args.use_wandb: - log_dict["rollout/step"] = step - wandb.log(log_dict) - - if args.use_tensorboard: - from slime.utils.tensorboard_utils import _TensorboardAdapter - - tb = _TensorboardAdapter(args) - tb.log(data=log_dict, step=step) + log_dict["rollout/step"] = step + logging_utils.log(args, log_dict, step_key="rollout/step") print(f"no filter rollout log {rollout_id}: {log_dict}") except Exception as e: - print(f"Failed to log to wandb: {e}") + print(f"Failed to log metrics: {e}") print(f"no filter rollout log {rollout_id}: {final_meta_info}") else: print(f"no filter rollout log {rollout_id}: {final_meta_info}") diff --git a/tests/test_glm4.7_30B_A3B_pd_mooncake.py b/tests/test_glm4.7_30B_A3B_pd_mooncake.py index 8ea17b5ce9..f839e246c8 100644 --- a/tests/test_glm4.7_30B_A3B_pd_mooncake.py +++ b/tests/test_glm4.7_30B_A3B_pd_mooncake.py @@ -145,7 +145,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{sglang_args} " f"{misc_args} " diff --git a/tests/test_mimo_7B_mtp_only_grad.py b/tests/test_mimo_7B_mtp_only_grad.py index 9e5e3d7d33..ca1c8b0fb3 100644 --- a/tests/test_mimo_7B_mtp_only_grad.py +++ b/tests/test_mimo_7B_mtp_only_grad.py @@ -124,7 +124,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{sglang_args} " f"{mtp_args} " diff --git a/tests/test_moonlight_16B_A3B.py b/tests/test_moonlight_16B_A3B.py index cd5818ce95..415279cefc 100644 --- a/tests/test_moonlight_16B_A3B.py +++ b/tests/test_moonlight_16B_A3B.py @@ -104,7 +104,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_moonlight_16B_A3B_r3.py b/tests/test_moonlight_16B_A3B_r3.py index facb2d0b70..5f8ac53ef2 100644 --- a/tests/test_moonlight_16B_A3B_r3.py +++ b/tests/test_moonlight_16B_A3B_r3.py @@ -105,7 +105,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_quick_start_glm4_9B.py b/tests/test_quick_start_glm4_9B.py index 30cb348594..b8bfa33505 100644 --- a/tests/test_quick_start_glm4_9B.py +++ b/tests/test_quick_start_glm4_9B.py @@ -103,7 +103,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_qwen2.5_0.5B_async_short.py b/tests/test_qwen2.5_0.5B_async_short.py index c3925d4432..9081cacf6d 100644 --- a/tests/test_qwen2.5_0.5B_async_short.py +++ b/tests/test_qwen2.5_0.5B_async_short.py @@ -96,7 +96,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{sglang_args} " f"{ci_args} " diff --git a/tests/test_qwen2.5_0.5B_opd_sglang.py b/tests/test_qwen2.5_0.5B_opd_sglang.py index eb0892fb91..e3d18756a7 100644 --- a/tests/test_qwen2.5_0.5B_opd_sglang.py +++ b/tests/test_qwen2.5_0.5B_opd_sglang.py @@ -180,7 +180,7 @@ def launch_teacher(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_qwen2.5_0.5B_ppo_critic_only_short.py b/tests/test_qwen2.5_0.5B_ppo_critic_only_short.py index c03d863d8e..37bfe797df 100644 --- a/tests/test_qwen2.5_0.5B_ppo_critic_only_short.py +++ b/tests/test_qwen2.5_0.5B_ppo_critic_only_short.py @@ -110,7 +110,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{ppo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{sglang_args} " f"{ci_args} " diff --git a/tests/test_qwen2.5_0.5B_sglang_config.py b/tests/test_qwen2.5_0.5B_sglang_config.py index f30c6ac6cf..04bfc8c8c2 100644 --- a/tests/test_qwen2.5_0.5B_sglang_config.py +++ b/tests/test_qwen2.5_0.5B_sglang_config.py @@ -125,7 +125,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_qwen2.5_0.5B_sglang_config_distributed.py b/tests/test_qwen2.5_0.5B_sglang_config_distributed.py index 68215b34c6..753ee76b4b 100644 --- a/tests/test_qwen2.5_0.5B_sglang_config_distributed.py +++ b/tests/test_qwen2.5_0.5B_sglang_config_distributed.py @@ -127,7 +127,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_qwen2.5_0.5B_short.py b/tests/test_qwen2.5_0.5B_short.py index 6f45095bfb..14b3983c60 100644 --- a/tests/test_qwen2.5_0.5B_short.py +++ b/tests/test_qwen2.5_0.5B_short.py @@ -96,7 +96,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{sglang_args} " f"{ci_args} " diff --git a/tests/test_qwen3.5_0.8B_gsm8k_async_short.py b/tests/test_qwen3.5_0.8B_gsm8k_async_short.py index 14d92ff4fe..161112520e 100644 --- a/tests/test_qwen3.5_0.8B_gsm8k_async_short.py +++ b/tests/test_qwen3.5_0.8B_gsm8k_async_short.py @@ -114,7 +114,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_qwen3.5_0.8B_gsm8k_short.py b/tests/test_qwen3.5_0.8B_gsm8k_short.py index 856413de75..792c7ff753 100644 --- a/tests/test_qwen3.5_0.8B_gsm8k_short.py +++ b/tests/test_qwen3.5_0.8B_gsm8k_short.py @@ -114,7 +114,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_qwen3.6_35B_A3B_pd_mooncake.py b/tests/test_qwen3.6_35B_A3B_pd_mooncake.py index 83e2449dae..da7737deb5 100644 --- a/tests/test_qwen3.6_35B_A3B_pd_mooncake.py +++ b/tests/test_qwen3.6_35B_A3B_pd_mooncake.py @@ -132,7 +132,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_qwen3_0.6B_parallel_check.py b/tests/test_qwen3_0.6B_parallel_check.py index 6213ad1838..12c7331a89 100644 --- a/tests/test_qwen3_0.6B_parallel_check.py +++ b/tests/test_qwen3_0.6B_parallel_check.py @@ -84,7 +84,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{ppo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{sglang_args} " f"{ci_args} " f"{misc_args} " diff --git a/tests/test_qwen3_30B_A3B.py b/tests/test_qwen3_30B_A3B.py index 976bbec0ae..ddb4c34447 100644 --- a/tests/test_qwen3_30B_A3B.py +++ b/tests/test_qwen3_30B_A3B.py @@ -128,7 +128,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_qwen3_30B_A3B_r3.py b/tests/test_qwen3_30B_A3B_r3.py index 71d87a3919..2c6faaccb0 100644 --- a/tests/test_qwen3_30B_A3B_r3.py +++ b/tests/test_qwen3_30B_A3B_r3.py @@ -128,7 +128,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_qwen3_4B_ckpt.py b/tests/test_qwen3_4B_ckpt.py index 4d0e082de4..d06d53cb31 100644 --- a/tests/test_qwen3_4B_ckpt.py +++ b/tests/test_qwen3_4B_ckpt.py @@ -113,7 +113,7 @@ def execute(mode: str = ""): f"{rollout_args} " f"{optimizer_args} " f"{ppo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{sglang_args} " f"{ci_args} " diff --git a/tests/test_qwen3_4B_ppo.py b/tests/test_qwen3_4B_ppo.py index 9b9105f499..033fe33f6e 100644 --- a/tests/test_qwen3_4B_ppo.py +++ b/tests/test_qwen3_4B_ppo.py @@ -128,7 +128,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{ppo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_qwen3_4B_ppo_disaggregate.py b/tests/test_qwen3_4B_ppo_disaggregate.py index 52119a5b67..3277caedc6 100644 --- a/tests/test_qwen3_4B_ppo_disaggregate.py +++ b/tests/test_qwen3_4B_ppo_disaggregate.py @@ -127,7 +127,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{ppo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_qwen3_4B_ppo_train_critic_only.py b/tests/test_qwen3_4B_ppo_train_critic_only.py index f0e1b19e08..2c22cd8591 100644 --- a/tests/test_qwen3_4B_ppo_train_critic_only.py +++ b/tests/test_qwen3_4B_ppo_train_critic_only.py @@ -127,7 +127,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{ppo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_sglang_config_mixed_offload.py b/tests/test_sglang_config_mixed_offload.py index 90d3d97389..e22948f286 100644 --- a/tests/test_sglang_config_mixed_offload.py +++ b/tests/test_sglang_config_mixed_offload.py @@ -134,7 +134,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/test_sglang_config_mixed_offload_ft.py b/tests/test_sglang_config_mixed_offload_ft.py index f017965c5f..7f8ab6a7d1 100644 --- a/tests/test_sglang_config_mixed_offload_ft.py +++ b/tests/test_sglang_config_mixed_offload_ft.py @@ -140,7 +140,7 @@ def execute(): f"{rollout_args} " f"{optimizer_args} " f"{grpo_args} " - f"{U.get_default_wandb_args(__file__)} " + f"{U.get_default_tracking_args(__file__)} " f"{perf_args} " f"{eval_args} " f"{sglang_args} " diff --git a/tests/utils/test_swanlab_support.py b/tests/utils/test_swanlab_support.py new file mode 100644 index 0000000000..4cfe25f5f8 --- /dev/null +++ b/tests/utils/test_swanlab_support.py @@ -0,0 +1,95 @@ +import sys +from types import SimpleNamespace + +import slime.utils.external_utils.command_utils as U + + +def test_get_default_swanlab_args(monkeypatch, tmp_path): + monkeypatch.setenv("SWANLAB_API_KEY", "test-key") + + args = U.get_default_swanlab_args(str(tmp_path / "test_short.py"), run_name_prefix="prefix", run_id="run123") + + assert "--use-swanlab" in args + assert "--swanlab-project slime-test_short" in args + assert "--swanlab-group prefix_run123" in args + assert "--swanlab-key 'test-key'" in args + assert "--disable-swanlab-random-suffix" in args + + +def test_get_default_tracking_args_can_target_wandb(monkeypatch, tmp_path): + monkeypatch.setenv("WANDB_API_KEY", "wandb-key") + + args = U.get_default_tracking_args(str(tmp_path / "test_short.py"), run_name_prefix="prefix", run_id="run123", backend="wandb") + + assert "--use-wandb" in args + assert "--wandb-project slime-test_short" in args + assert "--wandb-group prefix_run123" in args + assert "--wandb-key 'wandb-key'" in args + assert "--disable-wandb-random-suffix" in args + + +def test_logging_utils_distributes_to_wandb_and_swanlab(monkeypatch): + import slime.utils.logging_utils as logging_utils + + wandb_calls = {"log": [], "finish": 0} + swanlab_calls = {"log": [], "finish": 0, "init": []} + + wandb_stub = SimpleNamespace( + run=SimpleNamespace(id="wandb-run"), + log=lambda metrics: wandb_calls["log"].append(metrics), + finish=lambda: wandb_calls.__setitem__("finish", wandb_calls["finish"] + 1), + ) + swanlab_stub = SimpleNamespace( + log=lambda metrics, step=None: swanlab_calls["log"].append((metrics, step)), + finish=lambda: swanlab_calls.__setitem__("finish", swanlab_calls["finish"] + 1), + init=lambda **kwargs: swanlab_calls["init"].append(kwargs), + get_run=lambda: SimpleNamespace(id="swanlab-run"), + util=SimpleNamespace(generate_id=lambda: "abc123"), + ) + + monkeypatch.setitem(sys.modules, "swanlab", swanlab_stub) + monkeypatch.setattr(logging_utils, "wandb", wandb_stub) + monkeypatch.setattr(logging_utils, "swanlab_utils", SimpleNamespace( + init_swanlab_primary=lambda args: swanlab_calls["init"].append({"primary": True}), + init_swanlab_secondary=lambda args: swanlab_calls["init"].append({"secondary": True}), + reinit_swanlab_primary_with_open_metrics=lambda args, router_addr: swanlab_calls["init"].append( + {"router_addr": router_addr} + ), + finish_swanlab=lambda args: swanlab_calls.__setitem__("finish", swanlab_calls["finish"] + 1), + )) + monkeypatch.setitem(logging_utils.__dict__, "_TensorboardAdapter", lambda args: SimpleNamespace(log=lambda **kwargs: None)) + monkeypatch.setitem(logging_utils.__dict__, "wandb_utils", SimpleNamespace( + init_wandb_primary=lambda args, **kwargs: None, + init_wandb_secondary=lambda args, **kwargs: None, + reinit_wandb_primary_with_open_metrics=lambda args, router_addr: None, + )) + + args = SimpleNamespace(use_wandb=True, use_swanlab=True, use_tensorboard=False) + logging_utils.log(args, {"rollout/step": 3, "metric": 1.5}, step_key="rollout/step") + + assert wandb_calls["log"] == [{"rollout/step": 3, "metric": 1.5}] + assert swanlab_calls["log"] == [({"metric": 1.5}, 3)] + + +def test_swanlab_open_metrics_monitor_collects_and_logs(monkeypatch): + import slime.utils.swanlab_utils as swanlab_utils + + logged = [] + + monkeypatch.setattr(swanlab_utils, "swanlab", SimpleNamespace(log=lambda metrics, step=None: logged.append((metrics, step)))) + monkeypatch.setattr( + swanlab_utils, + "httpx", + SimpleNamespace(get=lambda url, timeout=10.0: SimpleNamespace( + raise_for_status=lambda: None, + text="""# HELP sglang_requests_total requests +sglang_requests_total{engine=rollout-0} 12 +sglang_latency_seconds 0.5 +""", + )), + ) + + monitor = swanlab_utils._SwanlabOpenMetricsMonitor(args=SimpleNamespace(), router_addr="http://127.0.0.1:8000", interval_s=1) + monitor.poll_once() + + assert logged == [({"sglang_engine/sglang_requests_total_engine_rollout_0": 12.0, "sglang_engine/sglang_latency_seconds": 0.5}, 0)]