Skip to content

Commit 2e5fcc1

Browse files
committed
update: dppo
1 parent dadd7ec commit 2e5fcc1

4 files changed

Lines changed: 38 additions & 18 deletions

File tree

examples/online/config/hb_onpolicy/algo/dppo.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,4 @@ algo:
 33  33   x_min: -1.0
 34  34   x_max: 1.0
 35  35   solver: ddpm
 36     - min_sampling_denoising_std: 0.05
 37  36   min_logprob_denoising_std: 0.1

examples/online/config/isaaclab_onpolicy/algo/dppo.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,5 +34,4 @@ algo:
 34  34   x_min: -1.0
 35  35   x_max: 1.0
 36  36   solver: ddpm
 37     - min_sampling_denoising_std: 0.05
 38  37   min_logprob_denoising_std: 0.1

flowrl/agent/online/dppo.py

Lines changed: 38 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -74,24 +74,43 @@ def jit_sample_actions(
 74  74   chain = jnp.transpose(
 75  75   jnp.concatenate([history[0], action[jnp.newaxis]], axis=0), (1, 0, 2))
 76  76   step_lps = jit_compute_chain_log_probs(actor, obs, chain, steps, min_logprob_std)
 77     - log_prob = step_lps.mean(axis=-1, keepdims=True)
     77 + log_prob = step_lps.mean(axis=-1, keepdims=True) # not used in DPPO training
 78  78   return rng, action, chain, log_prob
 79  79
 80  80
 81  81   @partial(jax.jit, static_argnames=(
 82     - "gamma", "gae_lambda", "gamma_denoising",
 83     - "clip_epsilon", "clip_epsilon_base", "clip_epsilon_rate",
 84     - "reward_scaling", "normalize_advantage",
 85     - "num_epochs", "num_minibatches", "batch_size",
 86     - "denoising_steps", "min_logprob_std",
     82 + "gamma",
     83 + "gae_lambda",
     84 + "gamma_denoising",
     85 + "clip_epsilon",
     86 + "clip_epsilon_base",
     87 + "clip_epsilon_rate",
     88 + "reward_scaling",
     89 + "normalize_advantage",
     90 + "num_epochs",
     91 + "num_minibatches",
     92 + "batch_size",
     93 + "denoising_steps",
     94 + "min_logprob_std",
 87  95   ))
 88  96   def jit_update_dppo(
 89     - rng: PRNGKey, actor: ContinuousDDPM, critic: Model, rollout: RolloutBatch,
 90     - gamma: float, gae_lambda: float, gamma_denoising: float,
 91     - clip_epsilon: float, clip_epsilon_base: float, clip_epsilon_rate: float,
 92     - reward_scaling: float, normalize_advantage: bool,
 93     - num_epochs: int, num_minibatches: int, batch_size: int,
 94     - denoising_steps: int, min_logprob_std: float,
     97 + rng: PRNGKey,
     98 + actor: ContinuousDDPM,
     99 + critic: Model,
    100 + rollout: RolloutBatch,
    101 + gamma: float,
    102 + gae_lambda: float,
    103 + gamma_denoising: float,
    104 + clip_epsilon: float,
    105 + clip_epsilon_base: float,
    106 + clip_epsilon_rate: float,
    107 + reward_scaling: float,
    108 + normalize_advantage: bool,
    109 + num_epochs: int,
    110 + num_minibatches: int,
    111 + batch_size: int,
    112 + denoising_steps: int,
    113 + min_logprob_std: float,
 95 114   ):
 96 115   T, B = rollout.rewards.shape[:2]
 97 116   K = denoising_steps
@@ -161,10 +180,14 @@ def actor_loss_fn(actor_params, dropout_rng):
161 180   new_actor, actor_info = actor.apply_gradient(actor_loss_fn)
162 181
163 182   def critic_loss_fn(critic_params, dropout_rng):
164     - v = critic.apply({"params": critic_params}, mb_obs,
165     - training=True, rngs={"dropout": dropout_rng})
    183 + v = critic.apply({"params": critic_params},
    184 + mb_obs,
    185 + training=True,
    186 + rngs={"dropout": dropout_rng})
166 187   v_loss = jnp.mean((mb_vs - v) ** 2)
167     - return v_loss, {"loss/value_loss": v_loss, "misc/value_mean": jnp.mean(v)}
    188 + return v_loss, {
    189 + "loss/value_loss": v_loss,
    190 + "misc/value_mean": jnp.mean(v)}
168 191
169 192   new_critic, critic_info = critic.apply_gradient(critic_loss_fn)
170 193   return (rng, new_actor, new_critic), {**actor_info, **critic_info}

flowrl/config/online/algo/dppo.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@ class DPPODiffusionConfig:
 15  15   x_min: float
 16  16   x_max: float
 17  17   solver: str
 18     - min_sampling_denoising_std: float
 19  18   min_logprob_denoising_std: float
 20  19
 21  20

0 commit comments

Comments
 (0)