@@ -74,24 +74,43 @@ def jit_sample_actions(
7474 chain = jnp .transpose (
7575 jnp .concatenate ([history [0 ], action [jnp .newaxis ]], axis = 0 ), (1 , 0 , 2 ))
7676 step_lps = jit_compute_chain_log_probs (actor , obs , chain , steps , min_logprob_std )
77- log_prob = step_lps .mean (axis = - 1 , keepdims = True )
77+ log_prob = step_lps .mean (axis = - 1 , keepdims = True ) # not used in DPPO training
7878 return rng , action , chain , log_prob
7979
8080
8181@partial (jax .jit , static_argnames = (
82- "gamma" , "gae_lambda" , "gamma_denoising" ,
83- "clip_epsilon" , "clip_epsilon_base" , "clip_epsilon_rate" ,
84- "reward_scaling" , "normalize_advantage" ,
85- "num_epochs" , "num_minibatches" , "batch_size" ,
86- "denoising_steps" , "min_logprob_std" ,
82+ "gamma" ,
83+ "gae_lambda" ,
84+ "gamma_denoising" ,
85+ "clip_epsilon" ,
86+ "clip_epsilon_base" ,
87+ "clip_epsilon_rate" ,
88+ "reward_scaling" ,
89+ "normalize_advantage" ,
90+ "num_epochs" ,
91+ "num_minibatches" ,
92+ "batch_size" ,
93+ "denoising_steps" ,
94+ "min_logprob_std" ,
8795))
8896def jit_update_dppo (
89- rng : PRNGKey , actor : ContinuousDDPM , critic : Model , rollout : RolloutBatch ,
90- gamma : float , gae_lambda : float , gamma_denoising : float ,
91- clip_epsilon : float , clip_epsilon_base : float , clip_epsilon_rate : float ,
92- reward_scaling : float , normalize_advantage : bool ,
93- num_epochs : int , num_minibatches : int , batch_size : int ,
94- denoising_steps : int , min_logprob_std : float ,
97+ rng : PRNGKey ,
98+ actor : ContinuousDDPM ,
99+ critic : Model ,
100+ rollout : RolloutBatch ,
101+ gamma : float ,
102+ gae_lambda : float ,
103+ gamma_denoising : float ,
104+ clip_epsilon : float ,
105+ clip_epsilon_base : float ,
106+ clip_epsilon_rate : float ,
107+ reward_scaling : float ,
108+ normalize_advantage : bool ,
109+ num_epochs : int ,
110+ num_minibatches : int ,
111+ batch_size : int ,
112+ denoising_steps : int ,
113+ min_logprob_std : float ,
95114):
96115 T , B = rollout .rewards .shape [:2 ]
97116 K = denoising_steps
@@ -161,10 +180,14 @@ def actor_loss_fn(actor_params, dropout_rng):
161180 new_actor , actor_info = actor .apply_gradient (actor_loss_fn )
162181
163182 def critic_loss_fn (critic_params , dropout_rng ):
164- v = critic .apply ({"params" : critic_params }, mb_obs ,
165- training = True , rngs = {"dropout" : dropout_rng })
183+ v = critic .apply ({"params" : critic_params },
184+ mb_obs ,
185+ training = True ,
186+ rngs = {"dropout" : dropout_rng })
166187 v_loss = jnp .mean ((mb_vs - v ) ** 2 )
167- return v_loss , {"loss/value_loss" : v_loss , "misc/value_mean" : jnp .mean (v )}
188+ return v_loss , {
189+ "loss/value_loss" : v_loss ,
190+ "misc/value_mean" : jnp .mean (v )}
168191
169192 new_critic , critic_info = critic .apply_gradient (critic_loss_fn )
170193 return (rng , new_actor , new_critic ), {** actor_info , ** critic_info }
0 commit comments