@@ -68,6 +68,7 @@ def train_preference_comparisons(
     fragment_length: int,
     transition_oversampling: float,
     initial_comparison_frac: float,
+    initial_epoch_multiplier: float,
     exploration_frac: float,
     trajectory_path: Optional[str],
     trajectory_generator_kwargs: Mapping[str, Any],
@@ -106,6 +107,9 @@ def train_preference_comparisons(
             sampled before the rest of training begins (using the randomly initialized
             agent). This can be used to pretrain the reward model before the agent
             is trained on the learned reward.
+        initial_epoch_multiplier: before agent training begins, train the reward
+            model for this many more epochs than usual (on fragments sampled from a
+            random agent).
         exploration_frac: fraction of trajectory samples that will be created using
             partially random actions, rather than the current policy. Might be helpful
             if the learned policy explores too little and gets stuck with a wrong
@@ -258,6 +262,7 @@ def train_preference_comparisons(
         fragment_length=fragment_length,
         transition_oversampling=transition_oversampling,
         initial_comparison_frac=initial_comparison_frac,
+        initial_epoch_multiplier=initial_epoch_multiplier,
         custom_logger=custom_logger,
         allow_variable_horizon=allow_variable_horizon,
         query_schedule=query_schedule,
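
For context, the sketch below illustrates the schedule this parameter controls: the first batch of comparisons (sized by `initial_comparison_frac`) comes from the randomly initialized agent, and the reward model is trained on it for `initial_epoch_multiplier` times the usual number of epochs before agent training starts. This is a minimal, self-contained illustration, not imitation's implementation; the function `reward_training_schedule` and the `base_epochs` / `num_iterations` arguments are hypothetical names, and only `initial_comparison_frac` and `initial_epoch_multiplier` correspond to the parameters in the diff.

```python
from typing import List, Tuple


def reward_training_schedule(
    total_comparisons: int,
    initial_comparison_frac: float,
    initial_epoch_multiplier: float,
    base_epochs: int,
    num_iterations: int,
) -> List[Tuple[int, int]]:
    """Return (num_comparisons, reward_epochs) for each reward-training round.

    Round 0 uses comparisons gathered from the randomly initialized agent and
    trains the reward model for base_epochs * initial_epoch_multiplier epochs;
    every later round falls back to base_epochs.
    """
    n_initial = int(total_comparisons * initial_comparison_frac)
    schedule = [(n_initial, int(base_epochs * initial_epoch_multiplier))]
    per_iteration = (total_comparisons - n_initial) // num_iterations
    schedule.extend([(per_iteration, base_epochs)] * num_iterations)
    return schedule


# Example: 1000 comparisons, 10% gathered up front, 3x extra pretraining epochs.
for round_idx, (n_pairs, epochs) in enumerate(
    reward_training_schedule(1000, 0.1, 3.0, 4, 5)
):
    print(f"round {round_idx}: {n_pairs} comparisons, {epochs} reward epochs")
```

Setting the multiplier above 1.0 front-loads reward-model fitting on the random-agent comparisons, so the agent starts training on a reward signal that has already been shaped by the initial preferences.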