|
| 1 | +from functools import partial |
| 2 | + |
| 3 | +import jax |
| 4 | +import jax.numpy as jnp |
| 5 | +import optax |
| 6 | + |
| 7 | +from flowrl.agent.base import BaseAgent |
| 8 | +from flowrl.agent.online.ppo import compute_gae |
| 9 | +from flowrl.config.online.algo.genpo import GenPOConfig |
| 10 | +from flowrl.flow.cnf import FlowBackbone |
| 11 | +from flowrl.flow.genpo import GenPOFlow |
| 12 | +from flowrl.functional.activation import get_activation |
| 13 | +from flowrl.module.critic import ScalarCritic |
| 14 | +from flowrl.module.mlp import MLP |
| 15 | +from flowrl.module.model import Model |
| 16 | +from flowrl.module.simba import Simba |
| 17 | +from flowrl.module.time_embedding import LearnableFourierEmbedding |
| 18 | +from flowrl.types import Metric, PRNGKey, RolloutBatch |
| 19 | + |
| 20 | +# ======= Sampling ======== |
| 21 | + |
| 22 | +@partial(jax.jit, static_argnames=("deterministic",)) |
| 23 | +def jit_sample_action_genpo(rng, actor, obs, deterministic): |
| 24 | + rng, x0_rng = jax.random.split(rng) |
| 25 | + B = obs.shape[0] |
| 26 | + aug_dim = 2 * actor.a_dim |
| 27 | + |
| 28 | + if deterministic: |
| 29 | + x0 = jnp.zeros((B, aug_dim)) |
| 30 | + x1 = actor.forward(obs, x0) |
| 31 | + action = x1[:, :actor.a_dim] |
| 32 | + log_prob = jnp.zeros(B) |
| 33 | + else: |
| 34 | + x0 = jax.random.normal(x0_rng, (B, aug_dim)) |
| 35 | + x1 = actor.forward(obs, x0) |
| 36 | + log_prob = actor.log_prob(obs, x0) |
| 37 | + action = x1[:, :actor.a_dim] |
| 38 | + |
| 39 | + return action, log_prob, x1 |
| 40 | + |
| 41 | + |
| 42 | +# ======= Training Update ======== |
| 43 | + |
| 44 | +@partial(jax.jit, static_argnames=( |
| 45 | + "gamma", "gae_lambda", "clip_epsilon", "entropy_coeff", "compress_coef", |
| 46 | + "reward_scaling", "normalize_advantage", |
| 47 | + "num_epochs", "num_minibatches", "batch_size", |
| 48 | +)) |
| 49 | +def jit_update_genpo( |
| 50 | + rng, actor, critic, rollout, |
| 51 | + gamma, gae_lambda, clip_epsilon, entropy_coeff, compress_coef, |
| 52 | + reward_scaling, normalize_advantage, |
| 53 | + num_epochs, num_minibatches, batch_size, |
| 54 | +): |
| 55 | + T, B = rollout.rewards.shape[:2] |
| 56 | + |
| 57 | + value_pred = critic(rollout.obs) |
| 58 | + next_value_pred = critic(rollout.next_obs) |
| 59 | + |
| 60 | + gae_vs, gae_advantages = jax.lax.stop_gradient( |
| 61 | + compute_gae( |
| 62 | + terminated=rollout.terminated, truncated=rollout.truncated, |
| 63 | + rewards=rollout.rewards * reward_scaling, |
| 64 | + values=value_pred, next_values=next_value_pred, |
| 65 | + gae_lambda=gae_lambda, gamma=gamma, |
| 66 | + ) |
| 67 | + ) |
| 68 | + |
| 69 | + if normalize_advantage: |
| 70 | + gae_advantages = (gae_advantages - gae_advantages.mean()) / (gae_advantages.std() + 1e-8) |
| 71 | + |
| 72 | + flat_obs = rollout.obs.reshape(T * B, -1) |
| 73 | + flat_actions = rollout.actions.reshape(T * B, -1) |
| 74 | + flat_advantages = gae_advantages.reshape(T * B, 1) |
| 75 | + flat_gae_vs = gae_vs.reshape(T * B, 1) |
| 76 | + flat_truncations = rollout.truncated.reshape(T * B, 1) |
| 77 | + flat_old_log_probs = rollout.extras["log_prob"].reshape(T * B, 1) |
| 78 | + flat_aug_actions = rollout.extras["aug_action"].reshape(T * B, -1) |
| 79 | + |
| 80 | + def epoch_step(carry, _): |
| 81 | + rng, actor, critic = carry |
| 82 | + rng, perm_rng = jax.random.split(rng) |
| 83 | + |
| 84 | + perm = jax.random.permutation(perm_rng, T * B) |
| 85 | + total = num_minibatches * batch_size |
| 86 | + perm = perm[:total] |
| 87 | + mb_indices = perm.reshape(num_minibatches, batch_size) |
| 88 | + |
| 89 | + def minibatch_step(carry, indices): |
| 90 | + rng, actor, critic = carry |
| 91 | + rng, compress_rng = jax.random.split(rng) |
| 92 | + |
| 93 | + mb_obs = flat_obs[indices] |
| 94 | + mb_advantages = flat_advantages[indices] |
| 95 | + mb_gae_vs = flat_gae_vs[indices] |
| 96 | + mb_truncations = flat_truncations[indices] |
| 97 | + mb_old_log_probs = flat_old_log_probs[indices] |
| 98 | + mb_aug_actions = flat_aug_actions[indices] |
| 99 | + |
| 100 | + compress_x0 = jax.random.normal(compress_rng, (batch_size, 2 * actor.a_dim)) |
| 101 | + |
| 102 | + def actor_loss_fn(actor_params, dropout_rng): |
| 103 | + # Log prob via inverse Jacobian (matches official impl) |
| 104 | + new_log_prob = actor.log_prob_via_inverse( |
| 105 | + mb_obs, mb_aug_actions, params=actor_params, |
| 106 | + )[:, jnp.newaxis] # (batch_size, 1) |
| 107 | + |
| 108 | + # PPO clipped surrogate |
| 109 | + rho_s = jnp.exp(new_log_prob - mb_old_log_probs) |
| 110 | + surrogate1 = rho_s * mb_advantages |
| 111 | + surrogate2 = jnp.clip(rho_s, 1 - clip_epsilon, 1 + clip_epsilon) * mb_advantages |
| 112 | + policy_loss = -jnp.mean(jnp.minimum(surrogate1, surrogate2)) |
| 113 | + |
| 114 | + # Entropy loss (maximize entropy = minimize mean log_prob) |
| 115 | + if entropy_coeff > 0: |
| 116 | + entropy_loss = entropy_coeff * jnp.mean(new_log_prob) |
| 117 | + else: |
| 118 | + entropy_loss = 0 |
| 119 | + |
| 120 | + # Compress loss on fresh forward samples (L2 norm, matches official) |
| 121 | + if compress_coef > 0: |
| 122 | + x1_fresh = actor.forward(mb_obs, compress_x0, params=actor_params) |
| 123 | + z_f = x1_fresh[:, :actor.a_dim] |
| 124 | + y_f = x1_fresh[:, actor.a_dim:] |
| 125 | + compress_loss = compress_coef * jnp.mean( |
| 126 | + jnp.sqrt(jnp.sum((z_f - y_f) ** 2, axis=-1) + 1e-8) |
| 127 | + ) |
| 128 | + else: |
| 129 | + compress_loss = 0 |
| 130 | + |
| 131 | + total_loss = policy_loss + entropy_loss + compress_loss |
| 132 | + |
| 133 | + return total_loss, { |
| 134 | + "loss/policy_loss": policy_loss, |
| 135 | + "loss/entropy_loss": entropy_loss, |
| 136 | + "loss/compress_loss": compress_loss, |
| 137 | + "misc/entropy": -jnp.mean(new_log_prob), |
| 138 | + "misc/policy_ratio": jnp.mean(rho_s), |
| 139 | + "misc/clipped_ratio": jnp.mean(jnp.abs(rho_s - 1.0) > clip_epsilon), |
| 140 | + "misc/log_prob_mean": jnp.mean(new_log_prob), |
| 141 | + } |
| 142 | + |
| 143 | + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) |
| 144 | + |
| 145 | + def critic_loss_fn(critic_params, dropout_rng): |
| 146 | + v = critic.apply( |
| 147 | + {"params": critic_params}, mb_obs, |
| 148 | + training=True, rngs={"dropout": dropout_rng}, |
| 149 | + ) |
| 150 | + v_error = (mb_gae_vs - v) * (1 - mb_truncations) |
| 151 | + v_loss = jnp.mean(v_error ** 2) |
| 152 | + return v_loss, { |
| 153 | + "loss/value_loss": v_loss, |
| 154 | + "misc/value_mean": jnp.mean(v), |
| 155 | + } |
| 156 | + |
| 157 | + new_critic, critic_metrics = critic.apply_gradient(critic_loss_fn) |
| 158 | + |
| 159 | + metrics = {**actor_metrics, **critic_metrics} |
| 160 | + return (rng, new_actor, new_critic), metrics |
| 161 | + |
| 162 | + (rng, actor, critic), mb_metrics = jax.lax.scan( |
| 163 | + minibatch_step, init=(rng, actor, critic), xs=mb_indices, |
| 164 | + ) |
| 165 | + return (rng, actor, critic), mb_metrics |
| 166 | + |
| 167 | + (rng, actor, critic), all_metrics = jax.lax.scan( |
| 168 | + epoch_step, init=(rng, actor, critic), length=num_epochs, |
| 169 | + ) |
| 170 | + |
| 171 | + metrics = jax.tree.map(lambda x: x.mean(), all_metrics) |
| 172 | + metrics.update({ |
| 173 | + "misc/reward_mean": rollout.rewards.mean(), |
| 174 | + "misc/obs_mean": flat_obs.mean(), |
| 175 | + "misc/obs_std": flat_obs.std(axis=0).mean(), |
| 176 | + "misc/action_l1_mean": jnp.abs(flat_actions).mean(), |
| 177 | + "misc/advantages_mean": flat_advantages.mean(), |
| 178 | + "misc/advantages_std": flat_advantages.std(axis=0).mean(), |
| 179 | + }) |
| 180 | + |
| 181 | + return rng, actor, critic, metrics |
| 182 | + |
| 183 | + |
| 184 | +# ======= Agent ======== |
| 185 | + |
| 186 | +class GenPOAgent(BaseAgent): |
| 187 | + """ |
| 188 | + Generative Policy Optimization (GenPO) |
| 189 | + Uses coupled Heun solver on augmented action space with exact log-probability |
| 190 | + via Jacobian determinant, combined with PPO clipping. |
| 191 | + """ |
| 192 | + name = "GenPOAgent" |
| 193 | + model_names = ["actor", "critic"] |
| 194 | + |
| 195 | + def __init__(self, obs_dim, act_dim, cfg: GenPOConfig, seed): |
| 196 | + super().__init__(obs_dim, act_dim, cfg, seed) |
| 197 | + self.cfg = cfg |
| 198 | + self.rng, actor_rng, critic_rng = jax.random.split(self.rng, 3) |
| 199 | + |
| 200 | + actor_activation = get_activation(cfg.flow.activation) |
| 201 | + critic_activation = get_activation(cfg.critic_activation) |
| 202 | + backbone_cls = {"mlp": MLP, "simba": Simba}[cfg.backbone_cls] |
| 203 | + |
| 204 | + flow_backbone = FlowBackbone( |
| 205 | + vel_predictor=backbone_cls( |
| 206 | + hidden_dims=cfg.flow.hidden_dims, |
| 207 | + activation=actor_activation, |
| 208 | + output_dim=act_dim, |
| 209 | + ), |
| 210 | + time_embedding=LearnableFourierEmbedding(output_dim=cfg.flow.time_dim), |
| 211 | + ) |
| 212 | + self.actor = GenPOFlow.create( |
| 213 | + network=flow_backbone, |
| 214 | + rng=actor_rng, |
| 215 | + inputs=( |
| 216 | + jnp.ones((1, act_dim)), |
| 217 | + jnp.ones((1, 1)), |
| 218 | + jnp.ones((1, obs_dim)), |
| 219 | + ), |
| 220 | + a_dim=act_dim, |
| 221 | + steps=cfg.flow.steps, |
| 222 | + mix_para=cfg.flow.mix_para, |
| 223 | + optimizer=optax.adam(learning_rate=cfg.flow.lr), |
| 224 | + clip_grad_norm=cfg.clip_grad_norm, |
| 225 | + ) |
| 226 | + |
| 227 | + critic_def = ScalarCritic( |
| 228 | + backbone=backbone_cls( |
| 229 | + hidden_dims=cfg.critic_hidden_dims, |
| 230 | + activation=critic_activation, |
| 231 | + ), |
| 232 | + ) |
| 233 | + self.critic = Model.create( |
| 234 | + critic_def, critic_rng, |
| 235 | + inputs=(jnp.ones((1, obs_dim)),), |
| 236 | + optimizer=optax.adam(learning_rate=cfg.critic_lr), |
| 237 | + clip_grad_norm=cfg.clip_grad_norm, |
| 238 | + ) |
| 239 | + |
| 240 | + def train_step(self, rollout: RolloutBatch, step: int) -> Metric: |
| 241 | + self.rng, self.actor, self.critic, metrics = jit_update_genpo( |
| 242 | + self.rng, self.actor, self.critic, rollout, |
| 243 | + gamma=self.cfg.gamma, |
| 244 | + gae_lambda=self.cfg.gae_lambda, |
| 245 | + clip_epsilon=self.cfg.clip_epsilon, |
| 246 | + entropy_coeff=self.cfg.entropy_coeff, |
| 247 | + compress_coef=self.cfg.compress_coef, |
| 248 | + reward_scaling=self.cfg.reward_scaling, |
| 249 | + normalize_advantage=self.cfg.normalize_advantage, |
| 250 | + num_epochs=self.cfg.num_epochs, |
| 251 | + num_minibatches=self.cfg.num_minibatches, |
| 252 | + batch_size=self.cfg.batch_size, |
| 253 | + ) |
| 254 | + return metrics |
| 255 | + |
| 256 | + def sample_actions(self, obs, deterministic=True, num_samples=1): |
| 257 | + assert num_samples == 1 |
| 258 | + self.rng, sample_rng = jax.random.split(self.rng) |
| 259 | + |
| 260 | + action, log_prob, aug_action = jit_sample_action_genpo( |
| 261 | + sample_rng, self.actor, obs, deterministic, |
| 262 | + ) |
| 263 | + |
| 264 | + return action, { |
| 265 | + "log_prob": log_prob[:, jnp.newaxis], |
| 266 | + "aug_action": aug_action, |
| 267 | + } |
0 commit comments