Skip to content

Commit 73de8a3

Browse files
committed
Changes Wandb runner name to log directory name
Approved-by: Clemens Schwarke
1 parent efbdf68 commit 73de8a3

File tree

1 file changed

+22
-17
lines changed

1 file changed

+22
-17
lines changed

rsl_rl/utils/wandb_utils.py

Lines changed: 22 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ class WandbSummaryWriter(SummaryWriter):
2121
def __init__(self, log_dir: str, flush_secs: int, cfg):
2222
super().__init__(log_dir, flush_secs)
2323

24+
# Get the run name
25+
run_name = os.path.split(log_dir)[-1]
26+
2427
try:
2528
project = cfg["wandb_project"]
2629
except KeyError:
@@ -29,35 +32,27 @@ def __init__(self, log_dir: str, flush_secs: int, cfg):
2932
try:
3033
entity = os.environ["WANDB_USERNAME"]
3134
except KeyError:
32-
raise KeyError(
33-
"Wandb username not found. Please run or add to ~/.bashrc: export WANDB_USERNAME=YOUR_USERNAME"
34-
)
35+
entity = None
3536

36-
wandb.init(project=project, entity=entity)
37+
# Initialize wandb
38+
wandb.init(project=project, entity=entity, name=run_name)
3739

38-
# Change generated name to project-number format
39-
wandb.run.name = project + wandb.run.name.split("-")[-1]
40+
# Add log directory to wandb
41+
wandb.config.update({"log_dir": log_dir})
4042

4143
self.name_map = {
4244
"Train/mean_reward/time": "Train/mean_reward_time",
4345
"Train/mean_episode_length/time": "Train/mean_episode_length_time",
4446
}
4547

46-
run_name = os.path.split(log_dir)[-1]
47-
48-
wandb.log({"log_dir": run_name})
49-
5048
def store_config(self, env_cfg, runner_cfg, alg_cfg, policy_cfg):
5149
wandb.config.update({"runner_cfg": runner_cfg})
5250
wandb.config.update({"policy_cfg": policy_cfg})
5351
wandb.config.update({"alg_cfg": alg_cfg})
54-
wandb.config.update({"env_cfg": asdict(env_cfg)})
55-
56-
def _map_path(self, path):
57-
if path in self.name_map:
58-
return self.name_map[path]
59-
else:
60-
return path
52+
try:
53+
wandb.config.update({"env_cfg": env_cfg.to_dict()})
54+
except Exception:
55+
wandb.config.update({"env_cfg": asdict(env_cfg)})
6156

6257
def add_scalar(self, tag, scalar_value, global_step=None, walltime=None, new_style=False):
6358
super().add_scalar(
@@ -80,3 +75,13 @@ def save_model(self, model_path, iter):
8075

8176
def save_file(self, path, iter=None):
8277
wandb.save(path, base_path=os.path.dirname(path))
78+
79+
"""
80+
Private methods.
81+
"""
82+
83+
def _map_path(self, path):
84+
if path in self.name_map:
85+
return self.name_map[path]
86+
else:
87+
return path

0 commit comments

Comments
 (0)