diff --git a/pufferlib/pufferl.py b/pufferlib/pufferl.py index 1c80de553..a2f34b324 100644 --- a/pufferlib/pufferl.py +++ b/pufferlib/pufferl.py @@ -1038,6 +1038,19 @@ class WandbLogger: def __init__(self, args, load_id=None, resume="allow"): import wandb + try: + git_branch = subprocess.check_output(["git", "branch", "--show-current"], text=True).strip() + git_commit = subprocess.check_output(["git", "rev-parse", "HEAD"], text=True).strip() + # Get the latest commit message (subject and body) + git_commit_message = subprocess.check_output(["git", "log", "-1", "--pretty=%B"], text=True).strip() + + # Format notes for the overview section + git_notes = f"**GitHub Repo State**\n\nBranch Name: {git_branch}\n\nLatest Commit Id: {git_commit}\n\nLatest Commit Message: {git_commit_message}" + except: + git_notes = ( + "Error fetching git info. Make sure you're running this in a git repository and have git installed." + ) + wandb.init( id=load_id or wandb.util.generate_id(), project=args["wandb_project"], @@ -1047,6 +1060,7 @@ def __init__(self, args, load_id=None, resume="allow"): resume=resume, config=args, name=args.get("wandb_name"), + notes=git_notes, tags=[args["tag"]] if args["tag"] is not None else [], ) self.wandb = wandb