Skip to content

Commit dadd7ec

Browse files
committed
update: dppo
1 parent daba10f commit dadd7ec

2 files changed

Lines changed: 72 additions & 17 deletions

File tree

flowrl/agent/online/dppo.py

Lines changed: 29 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,11 @@
2424

2525
@partial(jax.jit, static_argnames=("steps", "min_logprob_std"))
2626
def jit_compute_chain_log_probs(
27-
actor: ContinuousDDPM, obs: jnp.ndarray, chain: jnp.ndarray,
28-
steps: int, min_logprob_std: float,
27+
actor: ContinuousDDPM,
28+
obs: jnp.ndarray,
29+
chain: jnp.ndarray,
30+
steps: int,
31+
min_logprob_std: float,
2932
) -> jnp.ndarray:
3033
ts = quad_t_schedule(steps, n=actor.t_schedule_n,
3134
tmin=actor.t_diffusion[0], tmax=actor.t_diffusion[1])
@@ -57,6 +60,24 @@ def step_fn(_, i):
5760
return jnp.transpose(step_lps, (1, 0))
5861

5962

63+
@partial(jax.jit, static_argnames=("steps", "min_logprob_std"))
64+
def jit_sample_actions(
65+
rng: PRNGKey, actor: ContinuousDDPM, obs: jnp.ndarray,
66+
steps: int, min_logprob_std: float,
67+
) -> Tuple[PRNGKey, jnp.ndarray, jnp.ndarray, jnp.ndarray]:
68+
B = obs.shape[0]
69+
rng, xT_rng = jax.random.split(rng)
70+
xT = jax.random.normal(xT_rng, (B, actor.x_dim))
71+
rng, action, history = actor.sample(
72+
rng, xT, condition=obs, training=False, solver="ddpm",
73+
)
74+
chain = jnp.transpose(
75+
jnp.concatenate([history[0], action[jnp.newaxis]], axis=0), (1, 0, 2))
76+
step_lps = jit_compute_chain_log_probs(actor, obs, chain, steps, min_logprob_std)
77+
log_prob = step_lps.mean(axis=-1, keepdims=True)
78+
return rng, action, chain, log_prob
79+
80+
6081
@partial(jax.jit, static_argnames=(
6182
"gamma", "gae_lambda", "gamma_denoising",
6283
"clip_epsilon", "clip_epsilon_base", "clip_epsilon_rate",
@@ -246,20 +267,11 @@ def sample_actions(
246267
self, obs: jnp.ndarray, deterministic: bool = True, num_samples: int = 1,
247268
) -> Tuple[jnp.ndarray, Metric]:
248269
assert num_samples == 1, "DPPO only supports num_samples=1"
249-
B = obs.shape[0]
250-
self.rng, xT_rng = jax.random.split(self.rng)
251-
xT = jax.random.normal(xT_rng, (B, self.act_dim))
252-
253-
self.rng, action, history = self.actor.sample(
254-
self.rng, xT, condition=obs, training=False, solver="ddpm",
255-
)
256-
chain = jnp.transpose(
257-
jnp.concatenate([history[0], action[jnp.newaxis]], axis=0), (1, 0, 2))
258-
259-
step_lps = jit_compute_chain_log_probs(
260-
self.actor, obs, chain,
261-
self.cfg.diffusion.steps, self.cfg.diffusion.min_logprob_denoising_std,
270+
self.rng, action, chain, log_prob = jit_sample_actions(
271+
self.rng,
272+
self.actor,
273+
obs,
274+
self.cfg.diffusion.steps,
275+
self.cfg.diffusion.min_logprob_denoising_std,
262276
)
263-
log_prob = step_lps.mean(axis=-1, keepdims=True)
264-
265277
return action, {"log_prob": log_prob, "action_chains": chain}

scripts/isaaclab/dppo.sh

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Specify which GPUs to use
2+
GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use
3+
SEEDS=(0 1 2 3)
4+
NUM_EACH_GPU=1
5+
6+
PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]}))
7+
8+
TASKS=(
9+
"Isaac-Ant-v0"
10+
"Isaac-Humanoid-v0"
11+
)
12+
13+
SHARED_ARGS=(
14+
"algo=dppo"
15+
"log.tag=default"
16+
)
17+
18+
run_task() {
19+
task=$1
20+
seed=$2
21+
slot=$3
22+
num_gpus=${#GPUS[@]}
23+
device_idx=$((slot % num_gpus))
24+
device=${GPUS[$device_idx]}
25+
echo "Running $task $seed on GPU $device"
26+
unset CUDA_VISIBLE_DEVICES
27+
export CUDA_VISIBLE_DEVICES=$device
28+
export XLA_PYTHON_CLIENT_PREALLOCATE="false"
29+
command="python3 examples/online/main_isaaclab_onpolicy.py task=$task seed=$seed ${SHARED_ARGS[@]}"
30+
if [ -n "$DRY_RUN" ]; then
31+
echo $command
32+
else
33+
echo $command
34+
$command
35+
fi
36+
}
37+
38+
. env_parallel.bash
39+
if [ -n "$DRY_RUN" ]; then
40+
env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
41+
else
42+
env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]}
43+
fi

0 commit comments

Comments
 (0)