Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 11 additions & 8 deletions examples/retool/generate_with_retool.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from slime.rollout.sglang_rollout import GenerateState
from slime.utils.http_utils import post
from slime.utils import logging_utils
from slime.utils.types import Sample

# Import reward models
Expand Down Expand Up @@ -248,26 +249,28 @@ async def generate(args, sample: Sample, sampling_params) -> Sample:
"return_logprob": True, # Request log probabilities for training
}

# Log payload to wandb for debugging
# Log payload for debugging
try:
import wandb

if wandb.run is not None:
if getattr(args, "use_wandb", False) or getattr(args, "use_swanlab", False) or getattr(
args, "use_tensorboard", False
):
# Count available tools (from tool_specs)
available_tools = len(tool_specs)
# Count tools used in the current response
tools_used = response.count("<interpreter>")

wandb.log(
logging_utils.log(
args,
{
"debug/payload_length": len(prompt + response),
"debug/available_tools": available_tools,
"debug/tools_used": tools_used,
"debug/turn": turn,
}
},
step_key="debug/turn",
)
except ImportError:
pass # wandb not available
except Exception:
pass
Comment thread
asckaya marked this conversation as resolved.

output = await post(url, payload)

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ sections = ["FUTURE", "STDLIB", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"]
default_section = "THIRDPARTY"
extend_skip = ["setup.py", "docs/source/conf.py"]
known_first_party = ["slime", "slime_plugins"]
known_third_party = ["megatron", "wandb", "ray", "transformers"]
known_third_party = ["megatron", "ray", "swanlab", "transformers", "wandb"]
src_paths = ["slime", "slime_plugins"]


Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ qwen_vl_utils # for VLM
ray[default]
ring_flash_attn
sglang-router>=0.2.3
swanlab
tensorboard
transformers
wandb
35 changes: 35 additions & 0 deletions slime/utils/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1095,6 +1095,40 @@ def add_wandb_arguments(parser):
parser.add_argument("--wandb-run-id", type=str, default=None)
return parser

# swanlab
def add_swanlab_arguments(parser):
    """Register the SwanLab experiment-tracking options on ``parser``.

    Args:
        parser: An ``argparse``-style parser exposing ``add_argument``.

    Returns:
        The same ``parser`` object, so registration calls can be chained.
    """
    parser.add_argument("--use-swanlab", action="store_true", default=False)
    parser.add_argument(
        "--swanlab-mode",
        type=str,
        default=None,
        choices=["cloud", "local", "offline", "disabled"],
        help="SwanLab mode: cloud (default), local, offline, or disabled.",
    )
    parser.add_argument(
        "--swanlab-dir",
        type=str,
        default=None,
        help="Directory to store SwanLab logs. Default is swanlog in the current directory.",
    )
    parser.add_argument("--swanlab-key", type=str, default=None)
    parser.add_argument("--swanlab-host", type=str, default=None)
    parser.add_argument("--swanlab-web-host", type=str, default=None)
    parser.add_argument("--swanlab-workspace", type=str, default=None)
    parser.add_argument("--swanlab-project", type=str, default=None)
    parser.add_argument("--swanlab-group", type=str, default=None)
    parser.add_argument("--swanlab-experiment-name", type=str, default=None)
    parser.add_argument("--swanlab-open-metrics-interval", type=int, default=10)
    # Passing this flag stores False into ``swanlab_random_suffix`` (default
    # True), so the help text describes the *disabling* effect of the flag.
    parser.add_argument(
        "--disable-swanlab-random-suffix",
        action="store_false",
        dest="swanlab_random_suffix",
        default=True,
        help="Do not add a random suffix to the SwanLab group name (a suffix is added by default).",
    )
    parser.add_argument("--swanlab-run-id", type=str, default=None)
    return parser

# tensorboard
def add_tensorboard_arguments(parser):
# tb_project_name, tb_experiment_name
Expand Down Expand Up @@ -1379,6 +1413,7 @@ def add_ci_arguments(parser):
parser = add_algo_arguments(parser)
parser = add_on_policy_distillation_arguments(parser)
parser = add_wandb_arguments(parser)
parser = add_swanlab_arguments(parser)
parser = add_tensorboard_arguments(parser)
parser = add_router_arguments(parser)
parser = add_debug_arguments(parser)
Expand Down
57 changes: 46 additions & 11 deletions slime/utils/external_utils/command_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,29 +189,64 @@ def check_has_nvlink():


def get_default_wandb_args(test_file: str, run_name_prefix: str | None = None, run_id: str | None = None):
    """Build the wandb tracking command-line arguments for a test run.

    Delegates to ``get_default_tracking_args`` with ``backend="wandb"``. That
    function already checks ``WANDB_API_KEY`` and prints the skip message, so
    the check is not repeated here.

    Args:
        test_file: Path of the test file; used to derive the project name.
        run_name_prefix: Optional prefix prepended to the run/group name.
        run_id: Optional explicit run id; one is generated when omitted.

    Returns:
        The argument string, or ``""`` when ``WANDB_API_KEY`` is not set.
    """
    return get_default_tracking_args(test_file, run_name_prefix=run_name_prefix, run_id=run_id, backend="wandb")


def get_default_swanlab_args(test_file: str, run_name_prefix: str | None = None, run_id: str | None = None):
    """Build the SwanLab tracking command-line arguments for a test run.

    Forwards every argument to ``get_default_tracking_args`` with the backend
    pinned to ``"swanlab"`` and returns its result unchanged.
    """
    return get_default_tracking_args(
        test_file,
        backend="swanlab",
        run_id=run_id,
        run_name_prefix=run_name_prefix,
    )


def get_default_tracking_args(
test_file: str,
run_name_prefix: str | None = None,
run_id: str | None = None,
backend: str | None = None,
):
if backend is None:
backend = "swanlab" if os.environ.get("SWANLAB_API_KEY") else "wandb"

if backend == "wandb":
api_key_env = "WANDB_API_KEY"
use_flag = "--use-wandb"
project_flag = "--wandb-project"
group_flag = "--wandb-group"
key_flag = "--wandb-key"
suffix_flag = "--disable-wandb-random-suffix"
skip_msg = "Skip wandb configuration since WANDB_API_KEY is not found"
elif backend == "swanlab":
api_key_env = "SWANLAB_API_KEY"
use_flag = "--use-swanlab"
project_flag = "--swanlab-project"
group_flag = "--swanlab-group"
key_flag = "--swanlab-key"
suffix_flag = "--disable-swanlab-random-suffix"
skip_msg = "Skip swanlab configuration since SWANLAB_API_KEY is not found"
else:
raise ValueError(f"Unsupported backend: {backend}")

if not os.environ.get(api_key_env):
print(skip_msg)
return ""

test_file = Path(test_file)
test_name = test_file.stem
if len(test_name) < 6:
test_name = f"{test_file.parent.name}_{test_name}"

wandb_run_name = run_id or create_run_id()
run_name = run_id or create_run_id()
if (x := os.environ.get("GITHUB_COMMIT_NAME")) is not None:
wandb_run_name += f"_{x}"
run_name += f"_{x}"
if (x := run_name_prefix) is not None:
wandb_run_name = f"{x}_{wandb_run_name}"
run_name = f"{x}_{run_name}"

# Use the actual key value from environment to avoid shell expansion issues
wandb_key = os.environ.get("WANDB_API_KEY")
api_key = os.environ.get(api_key_env)
return (
"--use-wandb "
f"--wandb-project slime-{test_name} "
f"--wandb-group {wandb_run_name} "
f"--wandb-key '{wandb_key}' "
"--disable-wandb-random-suffix "
f"{use_flag} "
f"{project_flag} slime-{test_name} "
f"{group_flag} {run_name} "
f"{key_flag} '{api_key}' "
f"{suffix_flag} "
)


Expand Down
38 changes: 28 additions & 10 deletions slime/utils/logging_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import wandb

from . import wandb_utils
from . import swanlab_utils
from .tensorboard_utils import _TensorboardAdapter

_LOGGER_CONFIGURED = False
Expand All @@ -26,30 +27,47 @@ def configure_logger(prefix: str = ""):

def init_tracking(args, primary: bool = True, **kwargs):
    """Initialise every experiment tracker enabled on ``args``.

    Each backend is touched only when its ``use_*`` flag is set, so a disabled
    backend is never initialised.

    Args:
        args: Parsed arguments exposing ``use_wandb`` and ``use_swanlab``.
        primary: Selects the ``*_primary`` initialisers when true and the
            ``*_secondary`` ones otherwise.
        **kwargs: Forwarded to the wandb initialiser only; the SwanLab
            initialisers receive just ``args``.
    """
    if primary:
        if args.use_wandb:
            wandb_utils.init_wandb_primary(args, **kwargs)
        if args.use_swanlab:
            swanlab_utils.init_swanlab_primary(args)
    else:
        if args.use_wandb:
            wandb_utils.init_wandb_secondary(args, **kwargs)
        if args.use_swanlab:
            swanlab_utils.init_swanlab_secondary(args)


def update_tracking_open_metrics(args, router_addr):
    """Re-initialise the enabled primary trackers with open-metrics support.

    Args:
        args: Parsed arguments exposing ``use_wandb`` and ``use_swanlab``.
        router_addr: Router address passed through to each backend's
            ``reinit_*_primary_with_open_metrics`` helper.
    """
    if args.use_wandb:
        wandb_utils.reinit_wandb_primary_with_open_metrics(args, router_addr)
    if args.use_swanlab:
        swanlab_utils.reinit_swanlab_primary_with_open_metrics(args, router_addr)


def finish_tracking(args):
    """Shut down every experiment tracker enabled on ``args``.

    The backends are handled independently: a disabled wandb must not prevent
    SwanLab from being finished, so there is no early return.

    Args:
        args: Parsed arguments exposing ``use_wandb`` and ``use_swanlab``.
    """
    if args.use_wandb:
        try:
            if wandb.run is not None:
                wandb.finish()
        except Exception:
            # Best-effort shutdown: record the failure and carry on so the
            # other backend still gets finished.
            logging.getLogger(__name__).exception("Failed to finish wandb run")

    if args.use_swanlab:
        swanlab_utils.finish_swanlab(args)


# TODO further refactor, e.g. put TensorBoard init to the "init" part
def log(args, metrics, step_key: str):
    """Send ``metrics`` to every tracking backend enabled on ``args``.

    wandb receives the metrics dict unchanged. SwanLab and TensorBoard receive
    the metrics without the ``step_key`` entry, whose value is passed
    separately as the step.

    Args:
        args: Parsed arguments exposing ``use_wandb``, ``use_swanlab`` and
            ``use_tensorboard``.
        metrics: Mapping of metric name to value.
        step_key: Key in ``metrics`` holding the step value.

    Raises:
        KeyError: If SwanLab or TensorBoard is enabled and ``step_key`` is not
            present in ``metrics``.
    """
    if args.use_wandb:
        wandb.log(metrics)

    if not (args.use_swanlab or args.use_tensorboard):
        return

    # Split the step out once; both remaining backends take it separately.
    step = metrics[step_key]
    metrics_except_step = {k: v for k, v in metrics.items() if k != step_key}

    if args.use_swanlab:
        # Imported lazily so swanlab is only required when it is enabled.
        import swanlab

        swanlab.log(metrics_except_step, step=step)

    if args.use_tensorboard:
        _TensorboardAdapter(args).log(data=metrics_except_step, step=step)
Loading