Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 55 additions & 5 deletions runtime/core/evalue.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,26 @@ class BoxedEvalueList {
* unwrapped vals.
*/
BoxedEvalueList(EValue** wrapped_vals, T* unwrapped_vals, int size)
: wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {}
: wrapped_vals_(checkWrappedVals(wrapped_vals, size), size),
unwrapped_vals_(checkUnwrappedVals(unwrapped_vals)) {}

/*
* Constructs and returns the list of T specified by the EValue pointers
*/
executorch::aten::ArrayRef<T> get() const;

private:
static EValue** checkWrappedVals(EValue** wrapped_vals, int size) {
ET_CHECK_MSG(wrapped_vals != nullptr, "wrapped_vals cannot be null");
ET_CHECK_MSG(size >= 0, "size cannot be negative");
return wrapped_vals;
}

static T* checkUnwrappedVals(T* unwrapped_vals) {
ET_CHECK_MSG(unwrapped_vals != nullptr, "unwrapped_vals cannot be null");
return unwrapped_vals;
}

// Source of truth for the list
executorch::aten::ArrayRef<EValue*> wrapped_vals_;
// Same size as wrapped_vals
Expand Down Expand Up @@ -280,6 +293,7 @@ struct EValue {

/****** String Type ******/
/*implicit*/ EValue(executorch::aten::ArrayRef<char>* s) : tag(Tag::String) {
ET_CHECK_MSG(s != nullptr, "ArrayRef<char> pointer cannot be null");
payload.copyable_union.as_string_ptr = s;
}

Expand All @@ -289,13 +303,18 @@ struct EValue {

std::string_view toString() const {
ET_CHECK_MSG(isString(), "EValue is not a String.");
ET_CHECK_MSG(
payload.copyable_union.as_string_ptr != nullptr,
"EValue string pointer is null.");
return std::string_view(
payload.copyable_union.as_string_ptr->data(),
payload.copyable_union.as_string_ptr->size());
}

/****** Int List Type ******/
/*implicit*/ EValue(BoxedEvalueList<int64_t>* i) : tag(Tag::ListInt) {
ET_CHECK_MSG(
i != nullptr, "BoxedEvalueList<int64_t> pointer cannot be null");
payload.copyable_union.as_int_list_ptr = i;
}

Expand All @@ -305,12 +324,16 @@ struct EValue {

executorch::aten::ArrayRef<int64_t> toIntList() const {
ET_CHECK_MSG(isIntList(), "EValue is not an Int List.");
ET_CHECK_MSG(
payload.copyable_union.as_int_list_ptr != nullptr,
"EValue int list pointer is null.");
return (payload.copyable_union.as_int_list_ptr)->get();
}

/****** Bool List Type ******/
/*implicit*/ EValue(executorch::aten::ArrayRef<bool>* b)
: tag(Tag::ListBool) {
ET_CHECK_MSG(b != nullptr, "ArrayRef<bool> pointer cannot be null");
payload.copyable_union.as_bool_list_ptr = b;
}

Expand All @@ -320,12 +343,16 @@ struct EValue {

executorch::aten::ArrayRef<bool> toBoolList() const {
ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List.");
ET_CHECK_MSG(
payload.copyable_union.as_bool_list_ptr != nullptr,
"EValue bool list pointer is null.");
return *(payload.copyable_union.as_bool_list_ptr);
}

/****** Double List Type ******/
/*implicit*/ EValue(executorch::aten::ArrayRef<double>* d)
: tag(Tag::ListDouble) {
ET_CHECK_MSG(d != nullptr, "ArrayRef<double> pointer cannot be null");
payload.copyable_union.as_double_list_ptr = d;
}

Expand All @@ -335,12 +362,17 @@ struct EValue {

executorch::aten::ArrayRef<double> toDoubleList() const {
ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List.");
ET_CHECK_MSG(
payload.copyable_union.as_double_list_ptr != nullptr,
"EValue double list pointer is null.");
return *(payload.copyable_union.as_double_list_ptr);
}

/****** Tensor List Type ******/
/*implicit*/ EValue(BoxedEvalueList<executorch::aten::Tensor>* t)
: tag(Tag::ListTensor) {
ET_CHECK_MSG(
t != nullptr, "BoxedEvalueList<Tensor> pointer cannot be null");
payload.copyable_union.as_tensor_list_ptr = t;
}

Expand All @@ -350,13 +382,19 @@ struct EValue {

executorch::aten::ArrayRef<executorch::aten::Tensor> toTensorList() const {
ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List.");
ET_CHECK_MSG(
payload.copyable_union.as_tensor_list_ptr != nullptr,
"EValue tensor list pointer is null.");
return payload.copyable_union.as_tensor_list_ptr->get();
}

/****** List Optional Tensor Type ******/
/*implicit*/ EValue(
BoxedEvalueList<std::optional<executorch::aten::Tensor>>* t)
: tag(Tag::ListOptionalTensor) {
ET_CHECK_MSG(
t != nullptr,
"BoxedEvalueList<optional<Tensor>> pointer cannot be null");
payload.copyable_union.as_list_optional_tensor_ptr = t;
}

Expand All @@ -366,6 +404,11 @@ struct EValue {

executorch::aten::ArrayRef<std::optional<executorch::aten::Tensor>>
toListOptionalTensor() const {
ET_CHECK_MSG(
isListOptionalTensor(), "EValue is not a List Optional Tensor.");
ET_CHECK_MSG(
payload.copyable_union.as_list_optional_tensor_ptr != nullptr,
"EValue list optional tensor pointer is null.");
return payload.copyable_union.as_list_optional_tensor_ptr->get();
}

Expand Down Expand Up @@ -445,12 +488,19 @@ struct EValue {
// minor performance bump for a code maintainability hit
if (isTensor()) {
payload.as_tensor.~Tensor();
} else if (isTensorList()) {
for (auto& tensor : toTensorList()) {
} else if (
isTensorList() &&
payload.copyable_union.as_tensor_list_ptr != nullptr) {
// for (auto& tensor : toTensorList()) {
for (auto& tensor : payload.copyable_union.as_tensor_list_ptr->get()) {
tensor.~Tensor();
}
} else if (isListOptionalTensor()) {
for (auto& optional_tensor : toListOptionalTensor()) {
} else if (
isListOptionalTensor() &&
payload.copyable_union.as_list_optional_tensor_ptr != nullptr) {
// for (auto& optional_tensor : toListOptionalTensor()) {
for (auto& optional_tensor :
payload.copyable_union.as_list_optional_tensor_ptr->get()) {
optional_tensor.~optional();
}
}
Expand Down
127 changes: 127 additions & 0 deletions runtime/core/test/evalue_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -281,3 +281,130 @@ TEST_F(EValueTest, ConstructFromNullPtrAborts) {

ET_EXPECT_DEATH({ EValue evalue(null_ptr); }, "");
}

TEST_F(EValueTest, StringConstructorNullCheck) {
executorch::aten::ArrayRef<char>* null_string_ptr = nullptr;
ET_EXPECT_DEATH({ EValue evalue(null_string_ptr); }, "");
}

TEST_F(EValueTest, BoolListConstructorNullCheck) {
executorch::aten::ArrayRef<bool>* null_bool_list_ptr = nullptr;
ET_EXPECT_DEATH({ EValue evalue(null_bool_list_ptr); }, "");
}

TEST_F(EValueTest, DoubleListConstructorNullCheck) {
executorch::aten::ArrayRef<double>* null_double_list_ptr = nullptr;
ET_EXPECT_DEATH({ EValue evalue(null_double_list_ptr); }, "");
}

TEST_F(EValueTest, IntListConstructorNullCheck) {
BoxedEvalueList<int64_t>* null_int_list_ptr = nullptr;
ET_EXPECT_DEATH({ EValue evalue(null_int_list_ptr); }, "");
}

TEST_F(EValueTest, TensorListConstructorNullCheck) {
BoxedEvalueList<executorch::aten::Tensor>* null_tensor_list_ptr = nullptr;
ET_EXPECT_DEATH({ EValue evalue(null_tensor_list_ptr); }, "");
}

TEST_F(EValueTest, OptionalTensorListConstructorNullCheck) {
BoxedEvalueList<std::optional<executorch::aten::Tensor>>*
null_optional_tensor_list_ptr = nullptr;
ET_EXPECT_DEATH({ EValue evalue(null_optional_tensor_list_ptr); }, "");
}

TEST_F(EValueTest, BoxedEvalueListConstructorNullChecks) {
std::array<int64_t, 3> storage = {0, 0, 0};
std::array<EValue, 3> values = {
EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)};
std::array<EValue*, 3> values_p = {&values[0], &values[1], &values[2]};

// Test null wrapped_vals
ET_EXPECT_DEATH(
{ BoxedEvalueList<int64_t> list(nullptr, storage.data(), 3); }, "");

// Test null unwrapped_vals
ET_EXPECT_DEATH(
{ BoxedEvalueList<int64_t> list(values_p.data(), nullptr, 3); }, "");

// Test negative size
ET_EXPECT_DEATH(
{ BoxedEvalueList<int64_t> list(values_p.data(), storage.data(), -1); },
"");
}

TEST_F(EValueTest, toListOptionalTensorTypeCheck) {
// Create an EValue that's not a ListOptionalTensor
EValue e((int64_t)42);
EXPECT_TRUE(e.isInt());
EXPECT_FALSE(e.isListOptionalTensor());

// Should fail type check
ET_EXPECT_DEATH({ e.toListOptionalTensor(); }, "");
}

TEST_F(EValueTest, toStringNullPointerCheck) {
// Create an EValue with String tag but null pointer
EValue e;
e.tag = Tag::String;
e.payload.copyable_union.as_string_ptr = nullptr;

// Should pass isString() check but fail null pointer check
EXPECT_TRUE(e.isString());
ET_EXPECT_DEATH({ e.toString(); }, "");
}

TEST_F(EValueTest, toIntListNullPointerCheck) {
// Create an EValue with ListInt tag but null pointer
EValue e;
e.tag = Tag::ListInt;
e.payload.copyable_union.as_int_list_ptr = nullptr;

// Should pass isIntList() check but fail null pointer check
EXPECT_TRUE(e.isIntList());
ET_EXPECT_DEATH({ e.toIntList(); }, "");
}

TEST_F(EValueTest, toBoolListNullPointerCheck) {
// Create an EValue with ListBool tag but null pointer
EValue e;
e.tag = Tag::ListBool;
e.payload.copyable_union.as_bool_list_ptr = nullptr;

// Should pass isBoolList() check but fail null pointer check
EXPECT_TRUE(e.isBoolList());
ET_EXPECT_DEATH({ e.toBoolList(); }, "");
}

TEST_F(EValueTest, toDoubleListNullPointerCheck) {
// Create an EValue with ListDouble tag but null pointer
EValue e;
e.tag = Tag::ListDouble;
e.payload.copyable_union.as_double_list_ptr = nullptr;

// Should pass isDoubleList() check but fail null pointer check
EXPECT_TRUE(e.isDoubleList());
ET_EXPECT_DEATH({ e.toDoubleList(); }, "");
}

TEST_F(EValueTest, toTensorListNullPointerCheck) {
// Create an EValue with ListTensor tag but null pointer
EValue e;
e.tag = Tag::ListTensor;
e.payload.copyable_union.as_tensor_list_ptr = nullptr;

// Should pass isTensorList() check but fail null pointer check
EXPECT_TRUE(e.isTensorList());
ET_EXPECT_DEATH({ e.toTensorList(); }, "");
}

TEST_F(EValueTest, toListOptionalTensorNullPointerCheck) {
// Create an EValue with ListOptionalTensor tag but null pointer
EValue e;
e.tag = Tag::ListOptionalTensor;
e.payload.copyable_union.as_list_optional_tensor_ptr = nullptr;

// Should pass isListOptionalTensor() check but fail null pointer check
EXPECT_TRUE(e.isListOptionalTensor());
ET_EXPECT_DEATH({ e.toListOptionalTensor(); }, "");
}
Loading