Skip to content

Commit 4aa1772

Browse files
committed
feat: GenPO
1 parent 1c627d6 commit 4aa1772

7 files changed

Lines changed: 536 additions & 1 deletion

File tree

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# @package _global_
# GenPO on-policy training configuration.
# The keys under `algo` mirror the fields of GenPOConfig; the keys under
# `algo.flow` mirror GenPOFlowConfig.
algo:
  name: genpo
  num_envs: 1024
  rollout_length: 24
  backbone_cls: mlp  # the agent accepts "mlp" or "simba"
  critic_hidden_dims: [256, 256, 256]
  critic_activation: elu
  critic_lr: 0.001
  gamma: 0.99
  gae_lambda: 0.95
  clip_epsilon: 0.2
  reward_scaling: 1.0
  normalize_advantage: true
  # The update reshapes a permutation of the rollout into
  # (num_minibatches, batch_size), so num_minibatches * batch_size must not
  # exceed the number of transitions in a rollout. With the values here:
  # 4 * 6144 = 24576 = 1024 * 24.
  num_minibatches: 4
  num_epochs: 4
  batch_size: 6144
  clip_grad_norm: 1.0
  entropy_coeff: 0.01
  compress_coef: 0.01
  flow:
    activation: elu
    hidden_dims: [256, 256, 256]
    time_dim: 32
    steps: 5
    mix_para: 0.9
    lr: 0.0001

# NOTE(review): placed at the top level because GenPOConfig declares no
# `train_frames` field -- confirm against the experiment Config.
train_frames: 100_000_000

examples/online/main_isaaclab_onpolicy.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from omegaconf import OmegaConf
1010
from tqdm import tqdm
1111

12-
from flowrl.agent.online import DPPOAgent, FPOAgent, PPOAgent
12+
from flowrl.agent.online import DPPOAgent, FPOAgent, GenPOAgent, PPOAgent
1313
from flowrl.config.online.onpolicy_isaaclab_config import Config
1414
from flowrl.dataset.buffer.state import EmpiricalNormalizer
1515
from flowrl.env.online.isaaclab_env import IsaacLabEnv
@@ -23,6 +23,7 @@
2323
"ppo": PPOAgent,
2424
"dppo": DPPOAgent,
2525
"fpo": FPOAgent,
26+
"genpo": GenPOAgent,
2627
}
2728

2829

flowrl/agent/online/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from .dpmd import DPMDAgent
66
from .dppo import DPPOAgent
77
from .fpo import FPOAgent
8+
from .genpo import GenPOAgent
89
from .idem import IDEMAgent
910
from .nclql import NCLQLAgent
1011
from .ppo import PPOAgent

flowrl/agent/online/genpo.py

Lines changed: 267 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,267 @@
1+
from functools import partial
2+
3+
import jax
4+
import jax.numpy as jnp
5+
import optax
6+
7+
from flowrl.agent.base import BaseAgent
8+
from flowrl.agent.online.ppo import compute_gae
9+
from flowrl.config.online.algo.genpo import GenPOConfig
10+
from flowrl.flow.cnf import FlowBackbone
11+
from flowrl.flow.genpo import GenPOFlow
12+
from flowrl.functional.activation import get_activation
13+
from flowrl.module.critic import ScalarCritic
14+
from flowrl.module.mlp import MLP
15+
from flowrl.module.model import Model
16+
from flowrl.module.simba import Simba
17+
from flowrl.module.time_embedding import LearnableFourierEmbedding
18+
from flowrl.types import Metric, PRNGKey, RolloutBatch
19+
20+
# ======= Sampling ========
21+
22+
@partial(jax.jit, static_argnames=("deterministic",))
def jit_sample_action_genpo(rng, actor, obs, deterministic):
    """Sample one action per observation from the GenPO flow policy.

    Args:
        rng: JAX PRNG key. It is split once and the second half seeds the
            base noise; it is not consumed in deterministic mode.
        actor: Flow policy exposing ``a_dim``, ``forward(obs, x0)`` and
            ``log_prob(obs, x0)``.
        obs: Batch of observations; only ``obs.shape[0]`` is read here, the
            array itself is handed to the actor.
        deterministic: Static (compile-time) flag. When true the base sample
            is all zeros and the returned log-probability is a zero
            placeholder rather than a density value.

    Returns:
        ``(action, log_prob, x1)``: ``x1`` is the flow output computed from a
        base sample of width ``2 * actor.a_dim``, ``action`` is the first
        ``actor.a_dim`` columns of ``x1`` and ``log_prob`` has shape
        ``(batch,)``.
    """
    batch = obs.shape[0]
    aug_dim = 2 * actor.a_dim

    if deterministic:
        x0 = jnp.zeros((batch, aug_dim))
        log_prob = jnp.zeros(batch)
    else:
        # Keep the split (instead of using ``rng`` directly) so the noise
        # stream is the same as before this cleanup.
        _, x0_rng = jax.random.split(rng)
        x0 = jax.random.normal(x0_rng, (batch, aug_dim))
        log_prob = actor.log_prob(obs, x0)

    x1 = actor.forward(obs, x0)
    action = x1[:, :actor.a_dim]
    return action, log_prob, x1
40+
41+
42+
# ======= Training Update ========
43+
44+
@partial(jax.jit, static_argnames=(
    "gamma", "gae_lambda", "clip_epsilon", "entropy_coeff", "compress_coef",
    "reward_scaling", "normalize_advantage",
    "num_epochs", "num_minibatches", "batch_size",
))
def jit_update_genpo(
    rng, actor, critic, rollout,
    gamma, gae_lambda, clip_epsilon, entropy_coeff, compress_coef,
    reward_scaling, normalize_advantage,
    num_epochs, num_minibatches, batch_size,
):
    """Run the GenPO actor and critic updates for one rollout.

    GAE outputs are computed once from the critic as passed in, then
    ``num_epochs`` passes are made over the rollout. Each pass draws a fresh
    permutation of the ``T * B`` transitions, keeps the first
    ``num_minibatches * batch_size`` of them and applies one actor step and
    one critic step per minibatch.

    Args:
        rng: JAX PRNG key, split for permutations and compress-loss noise.
        actor: Flow policy exposing ``a_dim``, ``forward``,
            ``log_prob_via_inverse`` and ``apply_gradient``.
        critic: Value model that is callable on observations and exposes
            ``apply`` and ``apply_gradient``.
        rollout: Batch with ``obs``, ``next_obs``, ``actions``, ``rewards``,
            ``terminated``, ``truncated`` and ``extras["log_prob"]`` /
            ``extras["aug_action"]``; the first two axes of ``rewards`` are
            read as ``(T, B)``.
        gamma, gae_lambda: Forwarded to ``compute_gae``.
        clip_epsilon: Half-width of the PPO ratio clipping interval.
        entropy_coeff: Weight of the mean log-probability term; the term is
            skipped when not positive.
        compress_coef: Weight of the compress term; skipped when not positive.
        reward_scaling: Multiplier applied to rewards before GAE.
        normalize_advantage: If true, advantages are standardised over the
            whole rollout before minibatching.
        num_epochs, num_minibatches, batch_size: Static loop sizes.

    Returns:
        ``(rng, actor, critic, metrics)`` with the advanced key, the updated
        models, and metrics averaged over all epochs and minibatches plus a
        few rollout-level statistics.

    Raises:
        ValueError: If ``num_minibatches * batch_size`` exceeds ``T * B``.
    """
    T, B = rollout.rewards.shape[:2]

    # Shapes and the minibatch settings are static under jit, so this check
    # runs at trace time and replaces an otherwise opaque reshape failure.
    samples_per_epoch = num_minibatches * batch_size
    if samples_per_epoch > T * B:
        raise ValueError(
            f"num_minibatches * batch_size = {samples_per_epoch} exceeds the "
            f"{T * B} transitions in the rollout (T={T}, B={B})"
        )

    value_pred = critic(rollout.obs)
    next_value_pred = critic(rollout.next_obs)

    # GAE outputs are constants for the optimisation below.
    gae_vs, gae_advantages = jax.lax.stop_gradient(
        compute_gae(
            terminated=rollout.terminated, truncated=rollout.truncated,
            rewards=rollout.rewards * reward_scaling,
            values=value_pred, next_values=next_value_pred,
            gae_lambda=gae_lambda, gamma=gamma,
        )
    )

    if normalize_advantage:
        gae_advantages = (gae_advantages - gae_advantages.mean()) / (gae_advantages.std() + 1e-8)

    # Collapse the (T, B) axes so minibatches can be gathered by flat index.
    flat_obs = rollout.obs.reshape(T * B, -1)
    flat_actions = rollout.actions.reshape(T * B, -1)
    flat_advantages = gae_advantages.reshape(T * B, 1)
    flat_gae_vs = gae_vs.reshape(T * B, 1)
    flat_truncations = rollout.truncated.reshape(T * B, 1)
    flat_old_log_probs = rollout.extras["log_prob"].reshape(T * B, 1)
    flat_aug_actions = rollout.extras["aug_action"].reshape(T * B, -1)

    def epoch_step(carry, _):
        rng, actor, critic = carry
        rng, perm_rng = jax.random.split(rng)

        # Any transitions beyond samples_per_epoch are left out of this epoch.
        perm = jax.random.permutation(perm_rng, T * B)
        mb_indices = perm[:samples_per_epoch].reshape(num_minibatches, batch_size)

        def minibatch_step(carry, indices):
            rng, actor, critic = carry
            rng, compress_rng = jax.random.split(rng)

            mb_obs = flat_obs[indices]
            mb_advantages = flat_advantages[indices]
            mb_gae_vs = flat_gae_vs[indices]
            mb_truncations = flat_truncations[indices]
            mb_old_log_probs = flat_old_log_probs[indices]
            mb_aug_actions = flat_aug_actions[indices]

            compress_x0 = jax.random.normal(compress_rng, (batch_size, 2 * actor.a_dim))

            def actor_loss_fn(actor_params, dropout_rng):
                # Log prob via inverse Jacobian (matches official impl)
                new_log_prob = actor.log_prob_via_inverse(
                    mb_obs, mb_aug_actions, params=actor_params,
                )[:, jnp.newaxis]  # (batch_size, 1)

                # PPO clipped surrogate
                rho_s = jnp.exp(new_log_prob - mb_old_log_probs)
                surrogate1 = rho_s * mb_advantages
                surrogate2 = jnp.clip(rho_s, 1 - clip_epsilon, 1 + clip_epsilon) * mb_advantages
                policy_loss = -jnp.mean(jnp.minimum(surrogate1, surrogate2))

                # Entropy loss (maximize entropy = minimize mean log_prob)
                if entropy_coeff > 0:
                    entropy_loss = entropy_coeff * jnp.mean(new_log_prob)
                else:
                    entropy_loss = 0

                # Compress loss on fresh forward samples (L2 norm, matches official)
                if compress_coef > 0:
                    x1_fresh = actor.forward(mb_obs, compress_x0, params=actor_params)
                    z_f = x1_fresh[:, :actor.a_dim]
                    y_f = x1_fresh[:, actor.a_dim:]
                    compress_loss = compress_coef * jnp.mean(
                        jnp.sqrt(jnp.sum((z_f - y_f) ** 2, axis=-1) + 1e-8)
                    )
                else:
                    compress_loss = 0

                total_loss = policy_loss + entropy_loss + compress_loss

                return total_loss, {
                    "loss/policy_loss": policy_loss,
                    "loss/entropy_loss": entropy_loss,
                    "loss/compress_loss": compress_loss,
                    "misc/entropy": -jnp.mean(new_log_prob),
                    "misc/policy_ratio": jnp.mean(rho_s),
                    "misc/clipped_ratio": jnp.mean(jnp.abs(rho_s - 1.0) > clip_epsilon),
                    "misc/log_prob_mean": jnp.mean(new_log_prob),
                }

            new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn)

            def critic_loss_fn(critic_params, dropout_rng):
                v = critic.apply(
                    {"params": critic_params}, mb_obs,
                    training=True, rngs={"dropout": dropout_rng},
                )
                # Truncated transitions contribute zero error to the value loss.
                v_error = (mb_gae_vs - v) * (1 - mb_truncations)
                v_loss = jnp.mean(v_error ** 2)
                return v_loss, {
                    "loss/value_loss": v_loss,
                    "misc/value_mean": jnp.mean(v),
                }

            new_critic, critic_metrics = critic.apply_gradient(critic_loss_fn)

            metrics = {**actor_metrics, **critic_metrics}
            return (rng, new_actor, new_critic), metrics

        (rng, actor, critic), mb_metrics = jax.lax.scan(
            minibatch_step, init=(rng, actor, critic), xs=mb_indices,
        )
        return (rng, actor, critic), mb_metrics

    (rng, actor, critic), all_metrics = jax.lax.scan(
        epoch_step, init=(rng, actor, critic), length=num_epochs,
    )

    # Average every per-minibatch metric over all epochs and minibatches.
    metrics = jax.tree.map(lambda x: x.mean(), all_metrics)
    metrics.update({
        "misc/reward_mean": rollout.rewards.mean(),
        "misc/obs_mean": flat_obs.mean(),
        "misc/obs_std": flat_obs.std(axis=0).mean(),
        "misc/action_l1_mean": jnp.abs(flat_actions).mean(),
        "misc/advantages_mean": flat_advantages.mean(),
        "misc/advantages_std": flat_advantages.std(axis=0).mean(),
    })

    return rng, actor, critic, metrics
182+
183+
184+
# ======= Agent ========
185+
186+
class GenPOAgent(BaseAgent):
    """
    Generative Policy Optimization (GenPO)
    Uses coupled Heun solver on augmented action space with exact log-probability
    via Jacobian determinant, combined with PPO clipping.

    The agent owns two models: ``actor`` (a ``GenPOFlow``) and ``critic``
    (a ``ScalarCritic`` wrapped in ``Model``).
    """
    name = "GenPOAgent"
    model_names = ["actor", "critic"]

    def __init__(self, obs_dim, act_dim, cfg: GenPOConfig, seed):
        """Build the flow actor and the scalar critic from ``cfg``.

        Args:
            obs_dim: Observation width, used for the dummy init inputs.
            act_dim: Action width; also the velocity network's output width.
            cfg: GenPO hyper-parameters.
            seed: Forwarded to ``BaseAgent.__init__``.

        Raises:
            KeyError: If ``cfg.backbone_cls`` is not ``"mlp"`` or ``"simba"``.
        """
        super().__init__(obs_dim, act_dim, cfg, seed)
        self.cfg = cfg
        # self.rng is read here without being set in this class; presumably
        # BaseAgent.__init__ creates it from ``seed`` -- confirm.
        self.rng, actor_rng, critic_rng = jax.random.split(self.rng, 3)

        actor_activation = get_activation(cfg.flow.activation)
        critic_activation = get_activation(cfg.critic_activation)

        backbones = {"mlp": MLP, "simba": Simba}
        if cfg.backbone_cls not in backbones:
            # Same exception type as a plain dict lookup, with a usable message.
            raise KeyError(
                f"Unknown backbone_cls {cfg.backbone_cls!r}; "
                f"expected one of {sorted(backbones)}"
            )
        backbone_cls = backbones[cfg.backbone_cls]

        flow_backbone = FlowBackbone(
            vel_predictor=backbone_cls(
                hidden_dims=cfg.flow.hidden_dims,
                activation=actor_activation,
                output_dim=act_dim,
            ),
            time_embedding=LearnableFourierEmbedding(output_dim=cfg.flow.time_dim),
        )
        self.actor = GenPOFlow.create(
            network=flow_backbone,
            rng=actor_rng,
            # Dummy inputs for parameter initialisation.
            inputs=(
                jnp.ones((1, act_dim)),
                jnp.ones((1, 1)),
                jnp.ones((1, obs_dim)),
            ),
            a_dim=act_dim,
            steps=cfg.flow.steps,
            mix_para=cfg.flow.mix_para,
            optimizer=optax.adam(learning_rate=cfg.flow.lr),
            clip_grad_norm=cfg.clip_grad_norm,
        )

        critic_def = ScalarCritic(
            backbone=backbone_cls(
                hidden_dims=cfg.critic_hidden_dims,
                activation=critic_activation,
            ),
        )
        self.critic = Model.create(
            critic_def, critic_rng,
            inputs=(jnp.ones((1, obs_dim)),),
            optimizer=optax.adam(learning_rate=cfg.critic_lr),
            clip_grad_norm=cfg.clip_grad_norm,
        )

    def train_step(self, rollout: RolloutBatch, step: int) -> Metric:
        """Update actor and critic on ``rollout`` and return the metrics.

        ``step`` is accepted for interface compatibility and is not used.
        """
        self.rng, self.actor, self.critic, metrics = jit_update_genpo(
            self.rng, self.actor, self.critic, rollout,
            gamma=self.cfg.gamma,
            gae_lambda=self.cfg.gae_lambda,
            clip_epsilon=self.cfg.clip_epsilon,
            entropy_coeff=self.cfg.entropy_coeff,
            compress_coef=self.cfg.compress_coef,
            reward_scaling=self.cfg.reward_scaling,
            normalize_advantage=self.cfg.normalize_advantage,
            num_epochs=self.cfg.num_epochs,
            num_minibatches=self.cfg.num_minibatches,
            batch_size=self.cfg.batch_size,
        )
        return metrics

    def sample_actions(self, obs, deterministic=True, num_samples=1):
        """Sample actions for a batch of observations.

        Args:
            obs: Batch of observations.
            deterministic: Passed through to ``jit_sample_action_genpo``.
            num_samples: Must be 1; other values are not supported.

        Returns:
            ``(action, extras)`` where ``extras["log_prob"]`` has a trailing
            axis of size one added and ``extras["aug_action"]`` is the full
            flow output. ``jit_update_genpo`` reads both keys back from
            ``rollout.extras``.
        """
        assert num_samples == 1, "GenPOAgent only supports num_samples == 1"
        self.rng, sample_rng = jax.random.split(self.rng)

        action, log_prob, aug_action = jit_sample_action_genpo(
            sample_rng, self.actor, obs, deterministic,
        )

        return action, {
            "log_prob": log_prob[:, jnp.newaxis],
            "aug_action": aug_action,
        }

flowrl/config/online/algo/genpo.py

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
from dataclasses import dataclass
2+
from typing import List, Optional
3+
4+
from .base import BaseAlgoConfig
5+
6+
7+
@dataclass
class GenPOFlowConfig:
    """Settings for the GenPO flow actor (the ``flow`` section of ``GenPOConfig``)."""

    # Activation name; GenPOAgent resolves it with get_activation().
    activation: str
    # Hidden layer widths of the velocity-prediction backbone.
    hidden_dims: List[int]
    # Output width of the learnable Fourier time embedding.
    time_dim: int
    # Forwarded to GenPOFlow.create as ``steps``.
    steps: int
    # Forwarded to GenPOFlow.create as ``mix_para``.
    mix_para: float
    # Learning rate of the actor's Adam optimizer.
    lr: float
15+
16+
17+
@dataclass
class GenPOConfig(BaseAlgoConfig):
    """Hyper-parameters for ``GenPOAgent`` (the ``algo`` section of the config)."""

    name: str
    # Backbone selector; GenPOAgent accepts "mlp" or "simba".
    backbone_cls: str

    # --- critic ---
    critic_hidden_dims: List[int]
    critic_activation: str
    critic_lr: float

    # --- advantage estimation and PPO objective ---
    gamma: float
    gae_lambda: float
    clip_epsilon: float
    reward_scaling: float
    normalize_advantage: bool

    # --- rollout collection ---
    # NOTE(review): not read by GenPOAgent itself; presumably consumed by the
    # training script -- confirm.
    num_envs: int
    rollout_length: int

    # --- optimisation loop ---
    # num_minibatches * batch_size transitions are used per epoch.
    num_minibatches: int
    num_epochs: int
    batch_size: int
    # Passed as ``clip_grad_norm`` when creating both the actor and the critic.
    clip_grad_norm: Optional[float]

    # --- auxiliary loss weights (a term is skipped when its weight is <= 0) ---
    entropy_coeff: float
    compress_coef: float

    # --- flow actor ---
    flow: GenPOFlowConfig

0 commit comments

Comments
 (0)