Skip to content

Commit c2bc9dc

Browse files
author
Jan Michelfeit
committed
#625 extract _preference_feedback_schedule()
1 parent 567e980 commit c2bc9dc

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

src/imitation/algorithms/preference_comparisons.py

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -1668,16 +1668,9 @@ def train(
16681668
A dictionary with final metrics such as loss and accuracy
16691669
of the reward model.
16701670
"""
1671-
initial_comparisons = int(total_comparisons * self.initial_comparison_frac)
1672-
total_comparisons -= initial_comparisons
1673-
16741671
# Compute the number of comparisons to request at each iteration in advance.
1675-
vec_schedule = np.vectorize(self.query_schedule)
1676-
unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
1677-
probs = unnormalized_probs / np.sum(unnormalized_probs)
1678-
shares = util.oric(probs * total_comparisons)
1679-
schedule = [initial_comparisons] + shares.tolist()
1680-
print(f"Query schedule: {schedule}")
1672+
preference_query_schedule = self._preference_gather_schedule(total_comparisons)
1673+
print(f"Query schedule: {preference_query_schedule}")
16811674

16821675
timesteps_per_iteration, extra_timesteps = divmod(
16831676
total_timesteps,
@@ -1686,7 +1679,7 @@ def train(
16861679
reward_loss = None
16871680
reward_accuracy = None
16881681

1689-
for i, num_pairs in enumerate(schedule):
1682+
for i, num_pairs in enumerate(preference_query_schedule):
16901683
##########################
16911684
# Gather new preferences #
16921685
##########################
@@ -1749,3 +1742,13 @@ def train(
17491742
self._iteration += 1
17501743

17511744
return {"reward_loss": reward_loss, "reward_accuracy": reward_accuracy}
1745+
1746+
def _preference_gather_schedule(self, total_comparisons):
1747+
initial_comparisons = int(total_comparisons * self.initial_comparison_frac)
1748+
total_comparisons -= initial_comparisons
1749+
vec_schedule = np.vectorize(self.query_schedule)
1750+
unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
1751+
probs = unnormalized_probs / np.sum(unnormalized_probs)
1752+
shares = util.oric(probs * total_comparisons)
1753+
schedule = [initial_comparisons] + shares.tolist()
1754+
return schedule

0 commit comments

Comments
 (0)