@@ -144,6 +144,7 @@ void do_actual_learning(explore_eval& data, multi_learner& base, VW::multi_ex& e
144144 data.action_label = std::move (label_example->l .cb );
145145 label_example->l .cb = std::move (data.empty_label );
146146 }
147+
147148 multiline_learn_or_predict<false >(base, ec_seq, data.offset );
148149
149150 if (label_example != nullptr ) // restore label
@@ -159,11 +160,18 @@ void do_actual_learning(explore_eval& data, multi_learner& base, VW::multi_ex& e
159160 VW::action_scores& a_s = ec_seq[0 ]->pred .a_s ;
160161
161162 float action_probability = 0 ;
163+ bool action_found = false ;
162164 for (size_t i = 0 ; i < a_s.size (); i++)
163165 {
164- if (data.known_cost .action == a_s[i].action ) { action_probability = a_s[i].score ; }
166+ if (data.known_cost .action == a_s[i].action )
167+ {
168+ action_probability = a_s[i].score ;
169+ action_found = true ;
170+ }
165171 }
166172
173+ if (!action_found) { return ; }
174+
167175 float threshold = action_probability / data.known_cost .probability ;
168176
169177 if (!data.fixed_multiplier ) { data.multiplier = std::min (data.multiplier , 1 / threshold); }
@@ -183,15 +191,18 @@ void do_actual_learning(explore_eval& data, multi_learner& base, VW::multi_ex& e
183191 { ec_found = ec; }
184192 if (threshold > 1 ) { ec->weight *= threshold; }
185193 }
194+
186195 ec_found->l .cb .costs [0 ].probability = action_probability;
187196
188197 multiline_learn_or_predict<true >(base, ec_seq, data.offset );
189198
199+ // restore logged example
190200 if (threshold > 1 )
191201 {
192202 float inv_threshold = 1 .f / threshold;
193203 for (auto & ec : ec_seq) { ec->weight *= inv_threshold; }
194204 }
205+
195206 ec_found->l .cb .costs [0 ].probability = data.known_cost .probability ;
196207 data.update_count ++;
197208 }
0 commit comments