Skip to content

Commit 55aa6eb

Browse files
committed
Ensure that PC does at least one comparison per iteration.
1 parent d7a7da8 commit 55aa6eb

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

src/imitation/algorithms/preference_comparisons.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1678,6 +1678,8 @@ def train(
16781678
unnormalized_probs = vec_schedule(np.linspace(0, 1, self.num_iterations))
16791679
probs = unnormalized_probs / np.sum(unnormalized_probs)
16801680
shares = util.oric(probs * total_comparisons)
1681+
shares[shares <= 0] = 1 # ensure we at least request one comparison per iteration
1682+
16811683
schedule = [initial_comparisons] + shares.tolist()
16821684
print(f"Query schedule: {schedule}")
16831685

0 commit comments

Comments
 (0)