@@ -21,6 +21,9 @@ class WandbSummaryWriter(SummaryWriter):
21
21
def __init__ (self , log_dir : str , flush_secs : int , cfg ):
22
22
super ().__init__ (log_dir , flush_secs )
23
23
24
+ # Get the run name
25
+ run_name = os .path .split (log_dir )[- 1 ]
26
+
24
27
try :
25
28
project = cfg ["wandb_project" ]
26
29
except KeyError :
@@ -29,35 +32,27 @@ def __init__(self, log_dir: str, flush_secs: int, cfg):
29
32
try :
30
33
entity = os .environ ["WANDB_USERNAME" ]
31
34
except KeyError :
32
- raise KeyError (
33
- "Wandb username not found. Please run or add to ~/.bashrc: export WANDB_USERNAME=YOUR_USERNAME"
34
- )
35
+ entity = None
35
36
36
- wandb .init (project = project , entity = entity )
37
+ # Initialize wandb
38
+ wandb .init (project = project , entity = entity , name = run_name )
37
39
38
- # Change generated name to project-number format
39
- wandb .run . name = project + wandb . run . name . split ( "-" )[ - 1 ]
40
+ # Add log directory to wandb
41
+ wandb .config . update ({ "log_dir" : log_dir })
40
42
41
43
self .name_map = {
42
44
"Train/mean_reward/time" : "Train/mean_reward_time" ,
43
45
"Train/mean_episode_length/time" : "Train/mean_episode_length_time" ,
44
46
}
45
47
46
- run_name = os .path .split (log_dir )[- 1 ]
47
-
48
- wandb .log ({"log_dir" : run_name })
49
-
50
48
def store_config (self , env_cfg , runner_cfg , alg_cfg , policy_cfg ):
51
49
wandb .config .update ({"runner_cfg" : runner_cfg })
52
50
wandb .config .update ({"policy_cfg" : policy_cfg })
53
51
wandb .config .update ({"alg_cfg" : alg_cfg })
54
- wandb .config .update ({"env_cfg" : asdict (env_cfg )})
55
-
56
- def _map_path (self , path ):
57
- if path in self .name_map :
58
- return self .name_map [path ]
59
- else :
60
- return path
52
+ try :
53
+ wandb .config .update ({"env_cfg" : env_cfg .to_dict ()})
54
+ except Exception :
55
+ wandb .config .update ({"env_cfg" : asdict (env_cfg )})
61
56
62
57
def add_scalar (self , tag , scalar_value , global_step = None , walltime = None , new_style = False ):
63
58
super ().add_scalar (
@@ -80,3 +75,13 @@ def save_model(self, model_path, iter):
80
75
81
76
def save_file (self , path , iter = None ):
82
77
wandb .save (path , base_path = os .path .dirname (path ))
78
+
79
+ """
80
+ Private methods.
81
+ """
82
+
83
+ def _map_path (self , path ):
84
+ if path in self .name_map :
85
+ return self .name_map [path ]
86
+ else :
87
+ return path
0 commit comments