18
18
task_cfg = TASK_CONFIG
19
19
train_cfg = TRAIN_CONFIG
20
20
policy_config = POLICY_CONFIG
21
+ checkpoint_dir = os .path .join (train_cfg ['checkpoint_dir' ], task )
21
22
22
23
# device
23
24
device = os .environ ['DEVICE' ]
@@ -54,7 +55,7 @@ def train_bc(train_dataloader, val_dataloader, policy_config):
54
55
optimizer = make_optimizer (policy_config ['policy_class' ], policy )
55
56
56
57
# create checkpoint dir if not exists
57
- os .makedirs (train_cfg [ ' checkpoint_dir' ] , exist_ok = True )
58
+ os .makedirs (checkpoint_dir , exist_ok = True )
58
59
59
60
train_history = []
60
61
validation_history = []
@@ -102,19 +103,19 @@ def train_bc(train_dataloader, val_dataloader, policy_config):
102
103
print (summary_string )
103
104
104
105
if epoch % 200 == 0 :
105
- ckpt_path = os .path .join (train_cfg [ ' checkpoint_dir' ] , f"policy_epoch_{ epoch } _seed_{ train_cfg ['seed' ]} .ckpt" )
106
+ ckpt_path = os .path .join (checkpoint_dir , f"policy_epoch_{ epoch } _seed_{ train_cfg ['seed' ]} .ckpt" )
106
107
torch .save (policy .state_dict (), ckpt_path )
107
- plot_history (train_history , validation_history , epoch , train_cfg [ ' checkpoint_dir' ] , train_cfg ['seed' ])
108
+ plot_history (train_history , validation_history , epoch , checkpoint_dir , train_cfg ['seed' ])
108
109
109
- ckpt_path = os .path .join (train_cfg [ ' checkpoint_dir' ] , f'policy_last.ckpt' )
110
+ ckpt_path = os .path .join (checkpoint_dir , f'policy_last.ckpt' )
110
111
torch .save (policy .state_dict (), ckpt_path )
111
112
112
113
113
114
if __name__ == '__main__' :
114
115
# set seed
115
116
set_seed (train_cfg ['seed' ])
116
117
# create ckpt dir if not exists
117
- os .makedirs (train_cfg [ ' checkpoint_dir' ] , exist_ok = True )
118
+ os .makedirs (checkpoint_dir , exist_ok = True )
118
119
# number of training episodes
119
120
data_dir = os .path .join (task_cfg ['dataset_dir' ], task )
120
121
num_episodes = len (os .listdir (data_dir ))
@@ -123,7 +124,7 @@ def train_bc(train_dataloader, val_dataloader, policy_config):
123
124
train_dataloader , val_dataloader , stats , _ = load_data (data_dir , num_episodes , task_cfg ['camera_names' ],
124
125
train_cfg ['batch_size_train' ], train_cfg ['batch_size_val' ])
125
126
# save stats
126
- stats_path = os .path .join (train_cfg [ ' checkpoint_dir' ] , f'dataset_stats.pkl' )
127
+ stats_path = os .path .join (checkpoint_dir , f'dataset_stats.pkl' )
127
128
with open (stats_path , 'wb' ) as f :
128
129
pickle .dump (stats , f )
129
130
0 commit comments