Skip to content

Commit 9916386

Browse files
committed
pass task name as argument from cli
1 parent 7cc006a commit 9916386

File tree

4 files changed

+94
-26
lines changed

4 files changed

+94
-26
lines changed

config/config.py

+10-11
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# device
1313
device = 'cpu'
1414
if torch.cuda.is_available(): device = 'cuda'
15-
if torch.backends.mps.is_available(): device = 'mps'
15+
#if torch.backends.mps.is_available(): device = 'mps'
1616
os.environ['DEVICE'] = device
1717

1818
# robot port names
@@ -24,16 +24,15 @@
2424

2525
# task config (you can add new tasks)
2626
TASK_CONFIG = {
27-
'sort':{
28-
'dataset_dir': os.path.join(DATA_DIR, 'sort'),
29-
'num_episodes': 20,
30-
'episode_len': 300,
31-
'state_dim': 5,
32-
'action_dim': 5,
33-
'cam_width': 640,
34-
'cam_height': 480,
35-
'camera_names': ['front']
36-
}
27+
'dataset_dir': DATA_DIR,
28+
'num_episodes': 1,
29+
'episode_len': 350,
30+
'state_dim': 5,
31+
'action_dim': 5,
32+
'cam_width': 640,
33+
'cam_height': 480,
34+
'camera_names': ['front'],
35+
'camera_port': 0
3736
}
3837

3938

evaluate.py

+62-6
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,21 @@
44
import cv2
55
import torch
66
import pickle
7+
import argparse
8+
from time import time
79

10+
from robot import Robot
811
from training.utils import *
9-
from robot import PhysicalRobot
1012

1113

12-
task = 'sort'
13-
cfg = TASK_CONFIG[task]
14+
# parse the task name via command line
15+
parser = argparse.ArgumentParser()
16+
parser.add_argument('--task', type=str, default='task1')
17+
args = parser.parse_args()
18+
task = args.task
19+
20+
# config
21+
cfg = TASK_CONFIG
1422
policy_config = POLICY_CONFIG
1523
train_cfg = TRAIN_CONFIG
1624
device = os.environ['DEVICE']
@@ -33,12 +41,12 @@ def capture_image(cam):
3341

3442
if __name__ == "__main__":
3543
# init camera
36-
cam = cv2.VideoCapture(0)
44+
cam = cv2.VideoCapture(cfg['camera_port'])
3745
# Check if the camera opened successfully
3846
if not cam.isOpened():
3947
raise IOError("Cannot open camera")
4048
# init follower
41-
follower = PhysicalRobot(device_name=ROBOT_PORTS['follower'])
49+
follower = Robot(device_name=ROBOT_PORTS['follower'])
4250

4351
# load the policy
4452
ckpt_path = os.path.join(train_cfg['checkpoint_dir'], train_cfg['eval_ckpt_name'])
@@ -80,6 +88,9 @@ def capture_image(cam):
8088
all_time_actions = torch.zeros([cfg['episode_len'], cfg['episode_len']+num_queries, cfg['state_dim']]).to(device)
8189
qpos_history = torch.zeros((1, cfg['episode_len'], cfg['state_dim'])).to(device)
8290
with torch.inference_mode():
91+
# init buffers
92+
obs_replay = []
93+
action_replay = []
8394
for t in range(cfg['episode_len']):
8495
qpos_numpy = np.array(obs['qpos'])
8596
qpos = pre_process(qpos_numpy)
@@ -106,7 +117,6 @@ def capture_image(cam):
106117
raw_action = raw_action.squeeze(0).cpu().numpy()
107118
action = post_process(raw_action)
108119
action = pos2pwm(action).astype(int)
109-
print(action)
110120
### take action
111121
follower.set_goal_pos(action)
112122

@@ -116,8 +126,54 @@ def capture_image(cam):
116126
'qvel': vel2pwm(follower.read_velocity()),
117127
'images': {cn: capture_image(cam) for cn in cfg['camera_names']}
118128
}
129+
### store data
130+
obs_replay.append(obs)
131+
action_replay.append(action)
119132

120133
os.system('say "stop"')
134+
135+
# create a dictionary to store the data
136+
data_dict = {
137+
'/observations/qpos': [],
138+
'/observations/qvel': [],
139+
'/action': [],
140+
}
141+
# there may be more than one camera
142+
for cam_name in cfg['camera_names']:
143+
data_dict[f'/observations/images/{cam_name}'] = []
144+
145+
# store the observations and actions
146+
for o, a in zip(obs_replay, action_replay):
147+
data_dict['/observations/qpos'].append(o['qpos'])
148+
data_dict['/observations/qvel'].append(o['qvel'])
149+
data_dict['/action'].append(a)
150+
# store the images
151+
for cam_name in cfg['camera_names']:
152+
data_dict[f'/observations/images/{cam_name}'].append(o['images'][cam_name])
153+
154+
t0 = time()
155+
max_timesteps = len(data_dict['/observations/qpos'])
156+
# create data dir if it doesn't exist
157+
data_dir = cfg['dataset_dir']
158+
if not os.path.exists(data_dir): os.makedirs(data_dir)
159+
# count number of files in the directory
160+
idx = len([name for name in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, name))])
161+
dataset_path = os.path.join(data_dir, f'episode_{idx}')
162+
# save the data
163+
with h5py.File("data/demo/trained.hdf5", 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
164+
root.attrs['sim'] = True
165+
obs = root.create_group('observations')
166+
image = obs.create_group('images')
167+
for cam_name in cfg['camera_names']:
168+
_ = image.create_dataset(cam_name, (max_timesteps, cfg['cam_height'], cfg['cam_width'], 3), dtype='uint8',
169+
chunks=(1, cfg['cam_height'], cfg['cam_width'], 3), )
170+
qpos = obs.create_dataset('qpos', (max_timesteps, cfg['state_dim']))
171+
qvel = obs.create_dataset('qvel', (max_timesteps, cfg['state_dim']))
172+
# image = obs.create_dataset("image", (max_timesteps, 240, 320, 3), dtype='uint8', chunks=(1, 240, 320, 3))
173+
action = root.create_dataset('action', (max_timesteps, cfg['action_dim']))
174+
175+
for name, array in data_dict.items():
176+
root[name][...] = array
121177

122178
# disable torque
123179
follower._disable_torque()

record_episodes.py

+10-4
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@
22
import os
33
import cv2
44
import h5py
5+
import argparse
56
from tqdm import tqdm
67
from time import sleep, time
78
from training.utils import pwm2pos, pwm2vel
89

910
from robot import Robot
1011

11-
task = "sort"
12-
cfg = TASK_CONFIG[task]
12+
# parse the task name via command line
13+
parser = argparse.ArgumentParser()
14+
parser.add_argument('--task', type=str, default='task1')
15+
args = parser.parse_args()
16+
task = args.task
17+
18+
cfg = TASK_CONFIG
1319

1420

1521
def capture_image(cam):
@@ -30,7 +36,7 @@ def capture_image(cam):
3036

3137
if __name__ == "__main__":
3238
# init camera
33-
cam = cv2.VideoCapture(0)
39+
cam = cv2.VideoCapture(cfg['camera_port'])
3440
# Check if the camera opened successfully
3541
if not cam.isOpened():
3642
raise IOError("Cannot open camera")
@@ -97,7 +103,7 @@ def capture_image(cam):
97103
t0 = time()
98104
max_timesteps = len(data_dict['/observations/qpos'])
99105
# create data dir if it doesn't exist
100-
data_dir = cfg['dataset_dir']
106+
data_dir = os.path.join(cfg['dataset_dir'], task)
101107
if not os.path.exists(data_dir): os.makedirs(data_dir)
102108
# count number of files in the directory
103109
idx = len([name for name in os.listdir(data_dir) if os.path.isfile(os.path.join(data_dir, name))])

train.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,20 @@
22

33
import os
44
import pickle
5+
import argparse
56
from copy import deepcopy
67
import matplotlib.pyplot as plt
78

89
from training.utils import *
910

11+
# parse the task name via command line
12+
parser = argparse.ArgumentParser()
13+
parser.add_argument('--task', type=str, default='task1')
14+
args = parser.parse_args()
15+
task = args.task
16+
1017
# configs
11-
task = 'sort'
12-
task_cfg = TASK_CONFIG[task]
18+
task_cfg = TASK_CONFIG
1319
train_cfg = TRAIN_CONFIG
1420
policy_config = POLICY_CONFIG
1521

@@ -109,11 +115,12 @@ def train_bc(train_dataloader, val_dataloader, policy_config):
109115
set_seed(train_cfg['seed'])
110116
# create ckpt dir if not exists
111117
os.makedirs(train_cfg['checkpoint_dir'], exist_ok=True)
112-
# number of training episodes
113-
num_episodes = len(os.listdir(task_cfg['dataset_dir']))
118+
# number of training episodes
119+
data_dir = os.path.join(task_cfg['dataset_dir'], task)
120+
num_episodes = len(os.listdir(data_dir))
114121

115122
# load data
116-
train_dataloader, val_dataloader, stats, _ = load_data(task_cfg['dataset_dir'], num_episodes, task_cfg['camera_names'],
123+
train_dataloader, val_dataloader, stats, _ = load_data(data_dir, num_episodes, task_cfg['camera_names'],
117124
train_cfg['batch_size_train'], train_cfg['batch_size_val'])
118125
# save stats
119126
stats_path = os.path.join(train_cfg['checkpoint_dir'], f'dataset_stats.pkl')

0 commit comments

Comments
 (0)