4
4
import cv2
5
5
import torch
6
6
import pickle
7
+ import argparse
8
+ from time import time
7
9
10
+ from robot import Robot
8
11
from training .utils import *
9
- from robot import PhysicalRobot
10
12
11
13
12
- task = 'sort'
13
- cfg = TASK_CONFIG [task ]
14
+ # parse the task name via command line
15
+ parser = argparse .ArgumentParser ()
16
+ parser .add_argument ('--task' , type = str , default = 'task1' )
17
+ args = parser .parse_args ()
18
+ task = args .task
19
+
20
+ # config
21
+ cfg = TASK_CONFIG
14
22
policy_config = POLICY_CONFIG
15
23
train_cfg = TRAIN_CONFIG
16
24
device = os .environ ['DEVICE' ]
@@ -33,12 +41,12 @@ def capture_image(cam):
33
41
34
42
if __name__ == "__main__" :
35
43
# init camera
36
- cam = cv2 .VideoCapture (0 )
44
+ cam = cv2 .VideoCapture (cfg [ 'camera_port' ] )
37
45
# Check if the camera opened successfully
38
46
if not cam .isOpened ():
39
47
raise IOError ("Cannot open camera" )
40
48
# init follower
41
- follower = PhysicalRobot (device_name = ROBOT_PORTS ['follower' ])
49
+ follower = Robot (device_name = ROBOT_PORTS ['follower' ])
42
50
43
51
# load the policy
44
52
ckpt_path = os .path .join (train_cfg ['checkpoint_dir' ], train_cfg ['eval_ckpt_name' ])
@@ -80,6 +88,9 @@ def capture_image(cam):
80
88
all_time_actions = torch .zeros ([cfg ['episode_len' ], cfg ['episode_len' ]+ num_queries , cfg ['state_dim' ]]).to (device )
81
89
qpos_history = torch .zeros ((1 , cfg ['episode_len' ], cfg ['state_dim' ])).to (device )
82
90
with torch .inference_mode ():
91
+ # init buffers
92
+ obs_replay = []
93
+ action_replay = []
83
94
for t in range (cfg ['episode_len' ]):
84
95
qpos_numpy = np .array (obs ['qpos' ])
85
96
qpos = pre_process (qpos_numpy )
@@ -106,7 +117,6 @@ def capture_image(cam):
106
117
raw_action = raw_action .squeeze (0 ).cpu ().numpy ()
107
118
action = post_process (raw_action )
108
119
action = pos2pwm (action ).astype (int )
109
- print (action )
110
120
### take action
111
121
follower .set_goal_pos (action )
112
122
@@ -116,8 +126,54 @@ def capture_image(cam):
116
126
'qvel' : vel2pwm (follower .read_velocity ()),
117
127
'images' : {cn : capture_image (cam ) for cn in cfg ['camera_names' ]}
118
128
}
129
+ ### store data
130
+ obs_replay .append (obs )
131
+ action_replay .append (action )
119
132
120
133
os .system ('say "stop"' )
134
+
135
+ # create a dictionary to store the data
136
+ data_dict = {
137
+ '/observations/qpos' : [],
138
+ '/observations/qvel' : [],
139
+ '/action' : [],
140
+ }
141
+ # there may be more than one camera
142
+ for cam_name in cfg ['camera_names' ]:
143
+ data_dict [f'/observations/images/{ cam_name } ' ] = []
144
+
145
+ # store the observations and actions
146
+ for o , a in zip (obs_replay , action_replay ):
147
+ data_dict ['/observations/qpos' ].append (o ['qpos' ])
148
+ data_dict ['/observations/qvel' ].append (o ['qvel' ])
149
+ data_dict ['/action' ].append (a )
150
+ # store the images
151
+ for cam_name in cfg ['camera_names' ]:
152
+ data_dict [f'/observations/images/{ cam_name } ' ].append (o ['images' ][cam_name ])
153
+
154
+ t0 = time ()
155
+ max_timesteps = len (data_dict ['/observations/qpos' ])
156
+ # create data dir if it doesn't exist
157
+ data_dir = cfg ['dataset_dir' ]
158
+ if not os .path .exists (data_dir ): os .makedirs (data_dir )
159
+ # count number of files in the directory
160
+ idx = len ([name for name in os .listdir (data_dir ) if os .path .isfile (os .path .join (data_dir , name ))])
161
+ dataset_path = os .path .join (data_dir , f'episode_{ idx } ' )
162
+ # save the data
163
+ with h5py .File ("data/demo/trained.hdf5" , 'w' , rdcc_nbytes = 1024 ** 2 * 2 ) as root :
164
+ root .attrs ['sim' ] = True
165
+ obs = root .create_group ('observations' )
166
+ image = obs .create_group ('images' )
167
+ for cam_name in cfg ['camera_names' ]:
168
+ _ = image .create_dataset (cam_name , (max_timesteps , cfg ['cam_height' ], cfg ['cam_width' ], 3 ), dtype = 'uint8' ,
169
+ chunks = (1 , cfg ['cam_height' ], cfg ['cam_width' ], 3 ), )
170
+ qpos = obs .create_dataset ('qpos' , (max_timesteps , cfg ['state_dim' ]))
171
+ qvel = obs .create_dataset ('qvel' , (max_timesteps , cfg ['state_dim' ]))
172
+ # image = obs.create_dataset("image", (max_timesteps, 240, 320, 3), dtype='uint8', chunks=(1, 240, 320, 3))
173
+ action = root .create_dataset ('action' , (max_timesteps , cfg ['action_dim' ]))
174
+
175
+ for name , array in data_dict .items ():
176
+ root [name ][...] = array
121
177
122
178
# disable torque
123
179
follower ._disable_torque ()
0 commit comments