Skip to content

Commit 0406c0f

Browse files
authored
fix: explore_eval don't learn if logged action not in predicted actions (#4262)
1 parent e3685a0 commit 0406c0f

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

vowpalwabbit/core/src/reductions/explore_eval.cc

Lines changed: 12 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -144,6 +144,7 @@ void do_actual_learning(explore_eval& data, multi_learner& base, VW::multi_ex& e
144144
data.action_label = std::move(label_example->l.cb);
145145
label_example->l.cb = std::move(data.empty_label);
146146
}
147+
147148
multiline_learn_or_predict<false>(base, ec_seq, data.offset);
148149

149150
if (label_example != nullptr) // restore label
@@ -159,11 +160,18 @@ void do_actual_learning(explore_eval& data, multi_learner& base, VW::multi_ex& e
159160
VW::action_scores& a_s = ec_seq[0]->pred.a_s;
160161

161162
float action_probability = 0;
163+
bool action_found = false;
162164
for (size_t i = 0; i < a_s.size(); i++)
163165
{
164-
if (data.known_cost.action == a_s[i].action) { action_probability = a_s[i].score; }
166+
if (data.known_cost.action == a_s[i].action)
167+
{
168+
action_probability = a_s[i].score;
169+
action_found = true;
170+
}
165171
}
166172

173+
if (!action_found) { return; }
174+
167175
float threshold = action_probability / data.known_cost.probability;
168176

169177
if (!data.fixed_multiplier) { data.multiplier = std::min(data.multiplier, 1 / threshold); }
@@ -183,15 +191,18 @@ void do_actual_learning(explore_eval& data, multi_learner& base, VW::multi_ex& e
183191
{ ec_found = ec; }
184192
if (threshold > 1) { ec->weight *= threshold; }
185193
}
194+
186195
ec_found->l.cb.costs[0].probability = action_probability;
187196

188197
multiline_learn_or_predict<true>(base, ec_seq, data.offset);
189198

199+
// restore logged example
190200
if (threshold > 1)
191201
{
192202
float inv_threshold = 1.f / threshold;
193203
for (auto& ec : ec_seq) { ec->weight *= inv_threshold; }
194204
}
205+
195206
ec_found->l.cb.costs[0].probability = data.known_cost.probability;
196207
data.update_count++;
197208
}

0 commit comments

Comments
 (0)