@@ -130,6 +130,8 @@ class SACLoss(LossModule):
             valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
             shape of the data and that masking the data results in a valid data structure. Among other things, this may
             not be true in MARL settings or when using RNNs. Defaults to ``False``.
+        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+            Defaults to ``False``.
 
     Examples:
         >>> import torch
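For context, a minimal usage sketch of the new flag. The `actor` and `qvalue` names are illustrative placeholders, assumed to be built as in the Examples block above; only `deactivate_vmap` is new here:

    loss_module = SACLoss(
        actor_network=actor,      # e.g. a ProbabilisticActor, as in the Examples above
        qvalue_network=qvalue,    # e.g. a ValueOperator producing "state_action_value"
        deactivate_vmap=True,     # run the Q-network ensemble in a plain for loop instead of vmap
    )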
@@ -334,6 +336,7 @@ def __init__(
         separate_losses: bool = False,
         reduction: str = None,
         skip_done_states: bool = False,
+        deactivate_vmap: bool = False,
     ) -> None:
         self._in_keys = None
         self._out_keys = None
@@ -344,6 +347,7 @@ def __init__(
 
         # Actor
         self.delay_actor = delay_actor
+        self.deactivate_vmap = deactivate_vmap
         self.convert_to_functional(
             actor_network,
             "actor_network",
@@ -445,11 +449,16 @@ def __init__(
 
     def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
-            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+            self.qvalue_network,
+            (None, 0),
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
         )
         if self._version == 1:
             self._vmap_qnetwork00 = _vmap_func(
-                self.qvalue_network, randomness=self.vmap_randomness
+                self.qvalue_network,
+                randomness=self.vmap_randomness,
+                pseudo_vmap=self.deactivate_vmap,
             )
 
     @property
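The `pseudo_vmap=self.deactivate_vmap` argument suggests that `_vmap_func` can fall back to an eager loop over the stacked Q-network parameters instead of `torch.vmap`. A rough sketch of that fallback idea, with illustrative names rather than TorchRL's internals:

    import torch

    def loop_instead_of_vmap(func, x, stacked_params):
        # one parameter set per Q-network copy, stacked along dim 0
        outs = [func(x, params) for params in stacked_params.unbind(0)]
        return torch.stack(outs, 0)

The loop yields the same stacked output as the vmapped call, just without batching the computation across ensemble members.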
@@ -527,11 +536,13 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             self._value_estimator = TD1Estimator(
                 **hp,
                 value_network=value_net,
+                deactivate_vmap=self.deactivate_vmap,
             )
         elif value_type is ValueEstimators.TD0:
             self._value_estimator = TD0Estimator(
                 **hp,
                 value_network=value_net,
+                deactivate_vmap=self.deactivate_vmap,
             )
         elif value_type is ValueEstimators.GAE:
             raise NotImplementedError(
@@ -541,6 +552,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             self._value_estimator = TDLambdaEstimator(
                 **hp,
                 value_network=value_net,
+                deactivate_vmap=self.deactivate_vmap,
             )
         else:
             raise NotImplementedError(f"Unknown value type {value_type}")
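Assuming `loss_module` was built with `deactivate_vmap=True` as sketched earlier, the flag now reaches the value estimator with no extra argument at this call site:

    from torchrl.objectives import ValueEstimators

    loss_module.make_value_estimator(ValueEstimators.TD0)  # deactivate_vmap is forwarded internally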
@@ -673,7 +685,6 @@ def _actor_loss(
             raise RuntimeError(
                 f"Losses shape mismatch: {log_prob.shape} and {min_q_logprob.shape}"
             )
-
         return self._alpha * log_prob - min_q_logprob, {"log_prob": log_prob.detach()}
 
     @property
@@ -922,6 +933,8 @@ class DiscreteSACLoss(LossModule):
             valid, non-terminating next states. If ``True``, it is assumed that the done state can be broadcast to the
             shape of the data and that masking the data results in a valid data structure. Among other things, this may
             not be true in MARL settings or when using RNNs. Defaults to ``False``.
+        deactivate_vmap (bool, optional): whether to deactivate vmap calls and replace them with a plain for loop.
+            Defaults to ``False``.
 
     Examples:
         >>> import torch
@@ -1098,6 +1111,7 @@ def __init__(
         separate_losses: bool = False,
         reduction: str = None,
         skip_done_states: bool = False,
+        deactivate_vmap: bool = False,
     ):
         if reduction is None:
             reduction = "mean"
@@ -1110,6 +1124,7 @@ def __init__(
             "actor_network",
             create_target_params=self.delay_actor,
         )
+        self.deactivate_vmap = deactivate_vmap
         if separate_losses:
             # we want to make sure there are no duplicates in the params: the
             # params of critic must be refs to actor if they're shared
@@ -1184,7 +1199,10 @@ def __init__(
 
     def _make_vmap(self):
         self._vmap_qnetworkN0 = _vmap_func(
-            self.qvalue_network, (None, 0), randomness=self.vmap_randomness
+            self.qvalue_network,
+            (None, 0),
+            randomness=self.vmap_randomness,
+            pseudo_vmap=self.deactivate_vmap,
         )
 
     def _forward_value_estimator_keys(self, **kwargs) -> None:
@@ -1436,11 +1454,13 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             self._value_estimator = TD1Estimator(
                 **hp,
                 value_network=None,
+                deactivate_vmap=self.deactivate_vmap,
             )
         elif value_type is ValueEstimators.TD0:
             self._value_estimator = TD0Estimator(
                 **hp,
                 value_network=None,
+                deactivate_vmap=self.deactivate_vmap,
             )
         elif value_type is ValueEstimators.GAE:
             raise NotImplementedError(
@@ -1450,6 +1470,7 @@ def make_value_estimator(self, value_type: ValueEstimators = None, **hyperparams
             self._value_estimator = TDLambdaEstimator(
                 **hp,
                 value_network=None,
+                deactivate_vmap=self.deactivate_vmap,
             )
         else:
             raise NotImplementedError(f"Unknown value type {value_type}")