@@ -75,6 +75,19 @@ def sample(self, steps: int) -> Sequence[TrajectoryWithRew]:
             be the environment rewards, not ones from a reward model).
         """  # noqa: DAR202
 
+    def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
+        """Pre-train an agent if the trajectory generator uses one that
+        needs pre-training.
+
+        By default, this method does nothing and doesn't need
+        to be overridden in subclasses that don't require pre-training.
+
+        Args:
+            steps: number of environment steps to train for.
+            **kwargs: additional keyword arguments to pass on to
+                the training procedure.
+        """
+
     def train(self, steps: int, **kwargs: Any) -> None:
         """Train an agent if the trajectory generator uses one.
 
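Illustrative sketch (not part of the commit above): a trajectory generator that actually uses the new hook would override it in a subclass. The ExploringAgentTrainer below and its agent.learn(total_timesteps=...) call are hypothetical, and the base class is a simplified stand-in; only the unsupervised_pretrain(steps, **kwargs) signature comes from the hunk above.

from typing import Any


class TrajectoryGenerator:
    """Simplified stand-in for the base class modified in the hunk above."""

    def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
        """No-op by default, mirroring the hook added by the diff."""

    def train(self, steps: int, **kwargs: Any) -> None:
        """Train an agent if the trajectory generator uses one."""


class ExploringAgentTrainer(TrajectoryGenerator):
    """Hypothetical generator whose agent benefits from pre-training."""

    def __init__(self, agent: Any) -> None:
        self.agent = agent  # e.g. an RL algorithm exposing learn(total_timesteps=...)

    def unsupervised_pretrain(self, steps: int, **kwargs: Any) -> None:
        # Train the agent for `steps` environment steps (for example with an
        # intrinsic/exploration reward) before any preferences are gathered.
        self.agent.learn(total_timesteps=steps, **kwargs)
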
@@ -1493,7 +1506,7 @@ def __init__(
         transition_oversampling: float = 1,
         initial_comparison_frac: float = 0.1,
         initial_epoch_multiplier: float = 200.0,
-        initial_agent_pretrain_frac: float = 0.01,
+        initial_agent_pretrain_frac: float = 0.05,
         custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
         allow_variable_horizon: bool = False,
         rng: Optional[np.random.Generator] = None,
@@ -1685,6 +1698,15 @@ def train(
         reward_loss = None
         reward_accuracy = None
 
+        ###################################################
+        # Pre-training agent before gathering preferences #
+        ###################################################
+        with self.logger.accumulate_means("agent"):
+            self.logger.log(
+                f"Pre-training agent for {agent_pretrain_timesteps} timesteps"
+            )
+            self.trajectory_generator.unsupervised_pretrain(agent_pretrain_timesteps)
+
         for i, num_pairs in enumerate(preference_query_schedule):
             ##########################
             # Gather new preferences #
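Contextual note (not shown in the diff): agent_pretrain_timesteps is presumably derived earlier in train() from the initial_agent_pretrain_frac default changed above and the total timestep budget, so the pre-training block runs once before the first batch of preference queries. A hedged sketch of that arithmetic, with illustrative values:

# Hypothetical derivation; only `initial_agent_pretrain_frac` and
# `agent_pretrain_timesteps` appear in the diff itself.
total_timesteps = 1_000_000
initial_agent_pretrain_frac = 0.05  # new default from the hunk above
agent_pretrain_timesteps = int(total_timesteps * initial_agent_pretrain_frac)
print(agent_pretrain_timesteps)  # 50000 steps, vs. 10000 with the old 0.01 default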