Skip to content

Commit 791639f

Browse files
committed
fix checkpoint dir bug
1 parent 43502cd commit 791639f

File tree

2 files changed

+8
-7
lines changed

2 files changed

+8
-7
lines changed

config/config.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -60,5 +60,5 @@
6060
'batch_size_val': 8,
6161
'batch_size_train': 8,
6262
'eval_ckpt_name': 'policy_last.ckpt',
63-
'checkpoint_dir': os.path.join(CHECKPOINT_DIR, 'sort')
63+
'checkpoint_dir': CHECKPOINT_DIR
6464
}

train.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
task_cfg = TASK_CONFIG
1919
train_cfg = TRAIN_CONFIG
2020
policy_config = POLICY_CONFIG
21+
checkpoint_dir = os.path.join(train_cfg['checkpoint_dir'], task)
2122

2223
# device
2324
device = os.environ['DEVICE']
@@ -54,7 +55,7 @@ def train_bc(train_dataloader, val_dataloader, policy_config):
5455
optimizer = make_optimizer(policy_config['policy_class'], policy)
5556

5657
# create checkpoint dir if not exists
57-
os.makedirs(train_cfg['checkpoint_dir'], exist_ok=True)
58+
os.makedirs(checkpoint_dir, exist_ok=True)
5859

5960
train_history = []
6061
validation_history = []
@@ -102,19 +103,19 @@ def train_bc(train_dataloader, val_dataloader, policy_config):
102103
print(summary_string)
103104

104105
if epoch % 200 == 0:
105-
ckpt_path = os.path.join(train_cfg['checkpoint_dir'], f"policy_epoch_{epoch}_seed_{train_cfg['seed']}.ckpt")
106+
ckpt_path = os.path.join(checkpoint_dir, f"policy_epoch_{epoch}_seed_{train_cfg['seed']}.ckpt")
106107
torch.save(policy.state_dict(), ckpt_path)
107-
plot_history(train_history, validation_history, epoch, train_cfg['checkpoint_dir'], train_cfg['seed'])
108+
plot_history(train_history, validation_history, epoch, checkpoint_dir, train_cfg['seed'])
108109

109-
ckpt_path = os.path.join(train_cfg['checkpoint_dir'], f'policy_last.ckpt')
110+
ckpt_path = os.path.join(checkpoint_dir, f'policy_last.ckpt')
110111
torch.save(policy.state_dict(), ckpt_path)
111112

112113

113114
if __name__ == '__main__':
114115
# set seed
115116
set_seed(train_cfg['seed'])
116117
# create ckpt dir if not exists
117-
os.makedirs(train_cfg['checkpoint_dir'], exist_ok=True)
118+
os.makedirs(checkpoint_dir, exist_ok=True)
118119
# number of training episodes
119120
data_dir = os.path.join(task_cfg['dataset_dir'], task)
120121
num_episodes = len(os.listdir(data_dir))
@@ -123,7 +124,7 @@ def train_bc(train_dataloader, val_dataloader, policy_config):
123124
train_dataloader, val_dataloader, stats, _ = load_data(data_dir, num_episodes, task_cfg['camera_names'],
124125
train_cfg['batch_size_train'], train_cfg['batch_size_val'])
125126
# save stats
126-
stats_path = os.path.join(train_cfg['checkpoint_dir'], f'dataset_stats.pkl')
127+
stats_path = os.path.join(checkpoint_dir, f'dataset_stats.pkl')
127128
with open(stats_path, 'wb') as f:
128129
pickle.dump(stats, f)
129130

0 commit comments

Comments
 (0)