|
24 | 24 |
|
25 | 25 | @partial(jax.jit, static_argnames=("steps", "min_logprob_std")) |
26 | 26 | def jit_compute_chain_log_probs( |
27 | | - actor: ContinuousDDPM, obs: jnp.ndarray, chain: jnp.ndarray, |
28 | | - steps: int, min_logprob_std: float, |
| 27 | + actor: ContinuousDDPM, |
| 28 | + obs: jnp.ndarray, |
| 29 | + chain: jnp.ndarray, |
| 30 | + steps: int, |
| 31 | + min_logprob_std: float, |
29 | 32 | ) -> jnp.ndarray: |
30 | 33 | ts = quad_t_schedule(steps, n=actor.t_schedule_n, |
31 | 34 | tmin=actor.t_diffusion[0], tmax=actor.t_diffusion[1]) |
@@ -57,6 +60,24 @@ def step_fn(_, i): |
57 | 60 | return jnp.transpose(step_lps, (1, 0)) |
58 | 61 |
|
59 | 62 |
|
| 63 | +@partial(jax.jit, static_argnames=("steps", "min_logprob_std")) |
| 64 | +def jit_sample_actions( |
| 65 | + rng: PRNGKey, actor: ContinuousDDPM, obs: jnp.ndarray, |
| 66 | + steps: int, min_logprob_std: float, |
| 67 | +) -> Tuple[PRNGKey, jnp.ndarray, jnp.ndarray, jnp.ndarray]: |
| 68 | + B = obs.shape[0] |
| 69 | + rng, xT_rng = jax.random.split(rng) |
| 70 | + xT = jax.random.normal(xT_rng, (B, actor.x_dim)) |
| 71 | + rng, action, history = actor.sample( |
| 72 | + rng, xT, condition=obs, training=False, solver="ddpm", |
| 73 | + ) |
| 74 | + chain = jnp.transpose( |
| 75 | + jnp.concatenate([history[0], action[jnp.newaxis]], axis=0), (1, 0, 2)) |
| 76 | + step_lps = jit_compute_chain_log_probs(actor, obs, chain, steps, min_logprob_std) |
| 77 | + log_prob = step_lps.mean(axis=-1, keepdims=True) |
| 78 | + return rng, action, chain, log_prob |
| 79 | + |
| 80 | + |
60 | 81 | @partial(jax.jit, static_argnames=( |
61 | 82 | "gamma", "gae_lambda", "gamma_denoising", |
62 | 83 | "clip_epsilon", "clip_epsilon_base", "clip_epsilon_rate", |
@@ -246,20 +267,11 @@ def sample_actions( |
246 | 267 | self, obs: jnp.ndarray, deterministic: bool = True, num_samples: int = 1, |
247 | 268 | ) -> Tuple[jnp.ndarray, Metric]: |
248 | 269 | assert num_samples == 1, "DPPO only supports num_samples=1" |
249 | | - B = obs.shape[0] |
250 | | - self.rng, xT_rng = jax.random.split(self.rng) |
251 | | - xT = jax.random.normal(xT_rng, (B, self.act_dim)) |
252 | | - |
253 | | - self.rng, action, history = self.actor.sample( |
254 | | - self.rng, xT, condition=obs, training=False, solver="ddpm", |
255 | | - ) |
256 | | - chain = jnp.transpose( |
257 | | - jnp.concatenate([history[0], action[jnp.newaxis]], axis=0), (1, 0, 2)) |
258 | | - |
259 | | - step_lps = jit_compute_chain_log_probs( |
260 | | - self.actor, obs, chain, |
261 | | - self.cfg.diffusion.steps, self.cfg.diffusion.min_logprob_denoising_std, |
| 270 | + self.rng, action, chain, log_prob = jit_sample_actions( |
| 271 | + self.rng, |
| 272 | + self.actor, |
| 273 | + obs, |
| 274 | + self.cfg.diffusion.steps, |
| 275 | + self.cfg.diffusion.min_logprob_denoising_std, |
262 | 276 | ) |
263 | | - log_prob = step_lps.mean(axis=-1, keepdims=True) |
264 | | - |
265 | 277 | return action, {"log_prob": log_prob, "action_chains": chain} |
0 commit comments