Skip to content

Commit 16e9114

Browse files
authored
fix: [LAS] full predictions regardless of learn/predict path (#4273)
1 parent 0406c0f commit 16e9114

File tree

4 files changed

+48
-14
lines changed

4 files changed

+48
-14
lines changed

test/unit_test/cb_las_spanner_test.cc

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,4 +871,33 @@ BOOST_AUTO_TEST_CASE(check_singular_value_sum_diff_for_diff_ranks_is_small)
871871
VW::finish(vw);
872872
}
873873

874+
// Test added by this commit: checks that, after a learn() call, the prediction stored on the first
// example has exactly one entry per example that was passed in.
// The workspace is created with --cb_explore_adf --large_action_space and --max_actions set to d (2),
// while nine examples are supplied, so the example count exceeds max_actions.
// NOTE(review): none of the example strings carries a label before the '|', so this looks like the
// unlabeled learn path -- confirm against the parser.
BOOST_AUTO_TEST_CASE(check_learn_returns_correct_predictions)
875+
{
876+
auto d = 2;
877+
auto& vw = *VW::initialize(
878+
"--cb_explore_adf --large_action_space --max_actions " + std::to_string(d) + " --quiet --random_seed 12", nullptr,
879+
false, nullptr, nullptr);
880+
881+
VW::multi_ex examples;
882+
examples.push_back(VW::read_example(vw, "| 1:0.1 2:0.12 3:0.13 b200:2 c500:9"));
883+
examples.push_back(VW::read_example(vw, "| a_1:0.1 a_2:0.25 a_3:0.12 a100:1 a200:0.1"));
884+
examples.push_back(VW::read_example(vw, "| a_1:0.2 a_2:0.32 a_3:0.15 a100:0.2 a200:0.2"));
885+
examples.push_back(VW::read_example(vw, "| a_1:0.5 a_2:0.89 a_3:0.42 a100:1.4 a200:0.5"));
886+
examples.push_back(VW::read_example(vw, "| a_4:0.8 a_5:0.32 a_6:0.15 d1:0.2 d10: 0.2"));
887+
examples.push_back(VW::read_example(vw, "| a_7 a_8 a_9 v1:0.99"));
888+
examples.push_back(VW::read_example(vw, "| a_10 a_11 a_12"));
889+
examples.push_back(VW::read_example(vw, "| a_13 a_14 a_15"));
890+
examples.push_back(VW::read_example(vw, "| a_16 a_17 a_18:0.2"));
891+
892+
// Drive the learn path (not predict); the assertion below inspects what learn() left behind.
vw.learn(examples);
893+
894+
// The multi-example prediction is read from the first example of the sequence.
const auto& preds = examples[0]->pred.a_s;
895+
896+
// One prediction entry is expected per supplied example, even though max_actions is smaller.
BOOST_CHECK_EQUAL(preds.size(), examples.size());
897+
898+
// Hand the examples back and tear down the workspace created by VW::initialize above.
vw.finish_example(examples);
899+
900+
VW::finish(vw);
901+
}
902+
874903
BOOST_AUTO_TEST_SUITE_END()

vowpalwabbit/core/include/vw/core/reductions/cb/cb_actions_mask.h

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,10 @@ namespace reductions
1313
{
1414
class cb_actions_mask
1515
{
16-
public:
1716
// this reduction is used to get the actions mask from VW::actions_mask::reduction_features and apply it to the
1817
// outgoing predictions
19-
void learn(VW::LEARNER::multi_learner& base, multi_ex& examples);
20-
void predict(VW::LEARNER::multi_learner& base, multi_ex& examples);
18+
public:
19+
void update_predictions(multi_ex& examples, size_t initial_action_size);
2120

2221
private:
2322
template <bool is_learn>

vowpalwabbit/core/src/reductions/cb/cb_actions_mask.cc

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -8,19 +8,12 @@
88
#include "vw/core/action_score.h"
99
#include "vw/core/global_data.h"
1010
#include "vw/core/learner.h"
11+
#include "vw/core/reductions/cb/cb_adf.h"
1112
#include "vw/core/setup_base.h"
1213
#include "vw/core/vw.h"
1314

14-
void VW::reductions::cb_actions_mask::learn(VW::LEARNER::multi_learner& base, multi_ex& examples)
15+
void VW::reductions::cb_actions_mask::update_predictions(multi_ex& examples, size_t initial_action_size)
1516
{
16-
base.learn(examples);
17-
}
18-
19-
void VW::reductions::cb_actions_mask::predict(VW::LEARNER::multi_learner& base, multi_ex& examples)
20-
{
21-
auto initial_action_size = examples.size();
22-
base.predict(examples);
23-
2417
auto& preds = examples[0]->pred.a_s;
2518
std::vector<bool> actions_present(initial_action_size);
2619
for (const auto& action_score : preds) { actions_present[action_score.action] = true; }
@@ -34,10 +27,20 @@ void VW::reductions::cb_actions_mask::predict(VW::LEARNER::multi_learner& base,
3427
// Shared entry point for the cb_actions_mask reduction, instantiated once for learn and once for predict.
//   data     - the cb_actions_mask instance whose update_predictions() post-processes the predictions
//   base     - the next learner in the stack; receives the actual learn()/predict() call
//   examples - the multi-example sequence being processed
// Reading note: this is a diff view. A line preceded by an "NN-" marker is the pre-commit code that was
// removed, and a line preceded by an "NN+" marker is the code added by this commit.
template <bool is_learn>
3528
void learn_or_predict(VW::reductions::cb_actions_mask& data, VW::LEARNER::multi_learner& base, VW::multi_ex& examples)
3629
{
37-
if (is_learn) { data.learn(base, examples); }
30+
// Record the number of examples before the base learner is invoked; update_predictions() receives it.
auto initial_action_size = examples.size();
31+
if (is_learn)
32+
{
33+
base.learn(examples);
34+
35+
// NOTE(review): presumably returns the labeled example of the sequence, or nullptr when none is
// labeled -- confirm against CB_ADF::test_adf_sequence.
VW::example* label_example = CB_ADF::test_adf_sequence(examples);
36+
37+
// After learn, predictions are only post-processed when the base reports that learn() produces
// predictions, or when test_adf_sequence() found no example (nullptr).
if (base.learn_returns_prediction || label_example == nullptr)
38+
{ data.update_predictions(examples, initial_action_size); }
39+
}
3840
else
3941
{
40-
data.predict(base, examples);
42+
// Predict path: always run the base predict, then post-process the predictions.
base.predict(examples);
43+
data.update_predictions(examples, initial_action_size);
4144
}
4245
}
4346

@@ -56,6 +59,7 @@ VW::LEARNER::base_learner* VW::reductions::cb_actions_mask_setup(VW::setup_base_
5659
.set_output_label_type(VW::label_type_t::CB)
5760
.set_input_prediction_type(VW::prediction_type_t::ACTION_SCORES)
5861
.set_output_prediction_type(VW::prediction_type_t::ACTION_PROBS)
62+
.set_learn_returns_prediction(base->learn_returns_prediction)
5963
.build();
6064
return VW::LEARNER::make_base(*l);
6165
}

vowpalwabbit/core/src/reductions/cb/cb_explore_adf_large_action_space.cc

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,7 @@ void cb_explore_adf_large_action_space<randomized_svd_impl, spanner_impl>::predi
221221
if (is_learn)
222222
{
223223
base.learn(examples);
224+
if (base.learn_returns_prediction) { update_example_prediction(examples); }
224225
++_counter;
225226
}
226227
else
@@ -323,6 +324,7 @@ VW::LEARNER::base_learner* make_las_with_impl(VW::setup_base_i& stack_builder, V
323324
.set_print_example(explore_type::print_multiline_example)
324325
.set_persist_metrics(explore_type::persist_metrics)
325326
.set_save_load(explore_type::save_load)
327+
.set_learn_returns_prediction(base->learn_returns_prediction)
326328
.build(&all.logger);
327329
return make_base(*l);
328330
}

0 commit comments

Comments
 (0)