@@ -1493,6 +1493,7 @@ def __init__(
         transition_oversampling: float = 1,
         initial_comparison_frac: float = 0.1,
         initial_epoch_multiplier: float = 200.0,
+        initial_agent_pretrain_frac: float = 0.01,
         custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
         allow_variable_horizon: bool = False,
         rng: Optional[np.random.Generator] = None,
@@ -1542,6 +1543,9 @@ def __init__(
             initial_epoch_multiplier: before agent training begins, train the reward
                 model for this many more epochs than usual (on fragments sampled from a
                 random agent).
+            initial_agent_pretrain_frac: fraction of total_timesteps for which the
+                agent will be trained without preference gathering (and reward model
+                training).
             custom_logger: Where to log to; if None (default), creates a new logger.
             allow_variable_horizon: If False (default), algorithm will raise an
                 exception if it detects trajectories of different length during
@@ -1640,6 +1644,7 @@ def __init__(
         self.fragment_length = fragment_length
         self.initial_comparison_frac = initial_comparison_frac
         self.initial_epoch_multiplier = initial_epoch_multiplier
+        self.initial_agent_pretrain_frac = initial_agent_pretrain_frac
         self.num_iterations = num_iterations
         self.transition_oversampling = transition_oversampling
         if callable(query_schedule):
@@ -1672,10 +1677,11 @@ def train(
         preference_query_schedule = self._preference_gather_schedule(total_comparisons)
         print(f"Query schedule: {preference_query_schedule}")

-        timesteps_per_iteration, extra_timesteps = divmod(
-            total_timesteps,
-            self.num_iterations,
-        )
+        (
+            agent_pretrain_timesteps,
+            timesteps_per_iteration,
+            extra_timesteps,
+        ) = self._compute_timesteps(total_timesteps)
         reward_loss = None
         reward_accuracy = None

@@ -1752,3 +1758,13 @@ def _preference_gather_schedule(self, total_comparisons):
         shares = util.oric(probs * total_comparisons)
         schedule = [initial_comparisons] + shares.tolist()
         return schedule
+
+    def _compute_timesteps(self, total_timesteps: int) -> Tuple[int, int, int]:
+        agent_pretrain_timesteps = int(
+            total_timesteps * self.initial_agent_pretrain_frac
+        )
+        timesteps_per_iteration, extra_timesteps = divmod(
+            total_timesteps - agent_pretrain_timesteps,
+            self.num_iterations,
+        )
+        return agent_pretrain_timesteps, timesteps_per_iteration, extra_timesteps
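For orientation, a minimal standalone sketch of the split that _compute_timesteps performs, using hypothetical values for total_timesteps, initial_agent_pretrain_frac, and num_iterations (in the diff above, the fraction and iteration count are attributes of the trainer):

# Sketch of the timestep split introduced above; values are illustrative only.
total_timesteps = 10_000
initial_agent_pretrain_frac = 0.01
num_iterations = 5

# Timesteps reserved for agent pretraining, before any preferences are
# gathered or the reward model is trained.
agent_pretrain_timesteps = int(total_timesteps * initial_agent_pretrain_frac)

# The remaining budget is divided evenly across iterations; the remainder is
# returned separately so no timesteps are silently dropped.
timesteps_per_iteration, extra_timesteps = divmod(
    total_timesteps - agent_pretrain_timesteps,
    num_iterations,
)

print(agent_pretrain_timesteps, timesteps_per_iteration, extra_timesteps)
# -> 100 1980 0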