@@ -1506,7 +1506,7 @@ def __init__(
         transition_oversampling: float = 1,
         initial_comparison_frac: float = 0.1,
         initial_epoch_multiplier: float = 200.0,
-        initial_agent_pretrain_frac: float = 0.05,
+        unsupervised_agent_pretrain_frac: float = 0.05,
         custom_logger: Optional[imit_logger.HierarchicalLogger] = None,
         allow_variable_horizon: bool = False,
         rng: Optional[np.random.Generator] = None,
@@ -1556,7 +1556,7 @@ def __init__(
             initial_epoch_multiplier: before agent training begins, train the reward
                 model for this many more epochs than usual (on fragments sampled from a
                 random agent).
-            initial_agent_pretrain_frac: fraction of total_timesteps for which the
+            unsupervised_agent_pretrain_frac: fraction of total_timesteps for which the
                 agent will be trained without preference gathering (and reward model
                 training)
             custom_logger: Where to log to; if None (default), creates a new logger.
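To make the renamed argument concrete: it does not add training on top of total_timesteps, it only carves the unsupervised pre-training phase out of the existing budget (see the _compute_timesteps hunk further down). A rough illustration with made-up numbers, not taken from the diff:

    # Hypothetical numbers, for illustration only.
    total_timesteps = 1_000_000
    unsupervised_agent_pretrain_frac = 0.05
    # The agent trains on its own for this many timesteps before the first
    # preference query is issued; the remainder goes to the main loop.
    pretrain_timesteps = int(total_timesteps * unsupervised_agent_pretrain_frac)
    assert pretrain_timesteps == 50_000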
@@ -1657,7 +1657,7 @@ def __init__(
         self.fragment_length = fragment_length
         self.initial_comparison_frac = initial_comparison_frac
         self.initial_epoch_multiplier = initial_epoch_multiplier
-        self.initial_agent_pretrain_frac = initial_agent_pretrain_frac
+        self.unsupervised_agent_pretrain_frac = unsupervised_agent_pretrain_frac
         self.num_iterations = num_iterations
         self.transition_oversampling = transition_oversampling
         if callable(query_schedule):
@@ -1691,7 +1691,7 @@ def train(
         print(f"Query schedule: {preference_query_schedule}")

         (
-            agent_pretrain_timesteps,
+            unsupervised_pretrain_timesteps,
             timesteps_per_iteration,
             extra_timesteps,
         ) = self._compute_timesteps(total_timesteps)
@@ -1703,9 +1703,9 @@ def train(
         ###################################################
         with self.logger.accumulate_means("agent"):
             self.logger.log(
-                f"Pre-training agent for {agent_pretrain_timesteps} timesteps"
+                f"Pre-training agent for {unsupervised_pretrain_timesteps} timesteps"
             )
-            self.trajectory_generator.unsupervised_pretrain(agent_pretrain_timesteps)
+            self.trajectory_generator.unsupervised_pretrain(unsupervised_pretrain_timesteps)

         for i, num_pairs in enumerate(preference_query_schedule):
             ##########################
@@ -1782,11 +1782,11 @@ def _preference_gather_schedule(self, total_comparisons):
         return schedule

     def _compute_timesteps(self, total_timesteps: int) -> Tuple[int, int, int]:
-        agent_pretrain_timesteps = int(
-            total_timesteps * self.initial_agent_pretrain_frac
+        unsupervised_pretrain_timesteps = int(
+            total_timesteps * self.unsupervised_agent_pretrain_frac
         )
         timesteps_per_iteration, extra_timesteps = divmod(
-            total_timesteps - agent_pretrain_timesteps,
+            total_timesteps - unsupervised_pretrain_timesteps,
             self.num_iterations,
         )
-        return agent_pretrain_timesteps, timesteps_per_iteration, extra_timesteps
+        return unsupervised_pretrain_timesteps, timesteps_per_iteration, extra_timesteps
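For reviewers who want to sanity-check the arithmetic in this hunk, here is a standalone sketch of the same split (a re-implementation mirroring the method above, not a call into the library; the numbers in the usage line are hypothetical):

    from typing import Tuple

    def compute_timesteps_sketch(
        total_timesteps: int,
        unsupervised_agent_pretrain_frac: float,
        num_iterations: int,
    ) -> Tuple[int, int, int]:
        # Carve the unsupervised pre-training phase out of the total budget,
        # then spread what is left evenly across the preference iterations;
        # divmod returns the leftover separately rather than dropping it.
        unsupervised_pretrain_timesteps = int(
            total_timesteps * unsupervised_agent_pretrain_frac
        )
        timesteps_per_iteration, extra_timesteps = divmod(
            total_timesteps - unsupervised_pretrain_timesteps,
            num_iterations,
        )
        return unsupervised_pretrain_timesteps, timesteps_per_iteration, extra_timesteps

    # With a 100_000-step budget, a 0.05 pre-training fraction and 3 iterations:
    # 5_000 pre-training steps, 31_666 steps per iteration, 2 steps left over.
    assert compute_timesteps_sketch(100_000, 0.05, 3) == (5_000, 31_666, 2)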