@@ -1668,16 +1668,9 @@ def train(
16681668 A dictionary with final metrics such as loss and accuracy
16691669 of the reward model.
16701670 """
1671- initial_comparisons = int (total_comparisons * self .initial_comparison_frac )
1672- total_comparisons -= initial_comparisons
1673-
16741671 # Compute the number of comparisons to request at each iteration in advance.
1675- vec_schedule = np .vectorize (self .query_schedule )
1676- unnormalized_probs = vec_schedule (np .linspace (0 , 1 , self .num_iterations ))
1677- probs = unnormalized_probs / np .sum (unnormalized_probs )
1678- shares = util .oric (probs * total_comparisons )
1679- schedule = [initial_comparisons ] + shares .tolist ()
1680- print (f"Query schedule: { schedule } " )
1672+ preference_query_schedule = self ._preference_gather_schedule (total_comparisons )
1673+ print (f"Query schedule: { preference_query_schedule } " )
16811674
16821675 timesteps_per_iteration , extra_timesteps = divmod (
16831676 total_timesteps ,
@@ -1686,7 +1679,7 @@ def train(
16861679 reward_loss = None
16871680 reward_accuracy = None
16881681
1689- for i , num_pairs in enumerate (schedule ):
1682+ for i , num_pairs in enumerate (preference_query_schedule ):
16901683 ##########################
16911684 # Gather new preferences #
16921685 ##########################
@@ -1749,3 +1742,13 @@ def train(
17491742 self ._iteration += 1
17501743
17511744 return {"reward_loss" : reward_loss , "reward_accuracy" : reward_accuracy }
1745+
def _preference_gather_schedule(self, total_comparisons):
    """Plan how many preference comparisons to request at each iteration.

    A fixed fraction (``self.initial_comparison_frac``) of the budget is
    spent up front; the remainder is split across ``self.num_iterations``
    iterations in proportion to ``self.query_schedule`` evaluated on an
    evenly spaced grid over [0, 1].

    Args:
        total_comparisons: Total number of comparisons to distribute.

    Returns:
        A list of integers of length ``self.num_iterations + 1``: the
        initial batch size followed by one share per iteration.
    """
    initial_comparisons = int(total_comparisons * self.initial_comparison_frac)
    remaining = total_comparisons - initial_comparisons
    # Sample the (unnormalized) schedule on an even grid, then normalize
    # so the weights form a probability distribution over iterations.
    grid = np.linspace(0, 1, self.num_iterations)
    weights = np.vectorize(self.query_schedule)(grid)
    probs = weights / np.sum(weights)
    # NOTE(review): util.oric appears to round the fractional shares to
    # integers while preserving their sum — confirm against util docs.
    shares = util.oric(probs * remaining)
    return [initial_comparisons] + shares.tolist()
0 commit comments