diff --git a/runtime/core/evalue.h b/runtime/core/evalue.h index 0cea86dc30c..18927aba9cd 100644 --- a/runtime/core/evalue.h +++ b/runtime/core/evalue.h @@ -63,13 +63,26 @@ class BoxedEvalueList { * unwrapped vals. */ BoxedEvalueList(EValue** wrapped_vals, T* unwrapped_vals, int size) - : wrapped_vals_(wrapped_vals, size), unwrapped_vals_(unwrapped_vals) {} + : wrapped_vals_(checkWrappedVals(wrapped_vals, size), size), + unwrapped_vals_(checkUnwrappedVals(unwrapped_vals)) {} + /* * Constructs and returns the list of T specified by the EValue pointers */ executorch::aten::ArrayRef get() const; private: + static EValue** checkWrappedVals(EValue** wrapped_vals, int size) { + ET_CHECK_MSG(wrapped_vals != nullptr, "wrapped_vals cannot be null"); + ET_CHECK_MSG(size >= 0, "size cannot be negative"); + return wrapped_vals; + } + + static T* checkUnwrappedVals(T* unwrapped_vals) { + ET_CHECK_MSG(unwrapped_vals != nullptr, "unwrapped_vals cannot be null"); + return unwrapped_vals; + } + // Source of truth for the list executorch::aten::ArrayRef wrapped_vals_; // Same size as wrapped_vals @@ -280,6 +293,7 @@ struct EValue { /****** String Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* s) : tag(Tag::String) { + ET_CHECK_MSG(s != nullptr, "ArrayRef pointer cannot be null"); payload.copyable_union.as_string_ptr = s; } @@ -289,6 +303,9 @@ struct EValue { std::string_view toString() const { ET_CHECK_MSG(isString(), "EValue is not a String."); + ET_CHECK_MSG( + payload.copyable_union.as_string_ptr != nullptr, + "EValue string pointer is null."); return std::string_view( payload.copyable_union.as_string_ptr->data(), payload.copyable_union.as_string_ptr->size()); @@ -296,6 +313,8 @@ struct EValue { /****** Int List Type ******/ /*implicit*/ EValue(BoxedEvalueList* i) : tag(Tag::ListInt) { + ET_CHECK_MSG( + i != nullptr, "BoxedEvalueList pointer cannot be null"); payload.copyable_union.as_int_list_ptr = i; } @@ -305,12 +324,16 @@ struct EValue { executorch::aten::ArrayRef toIntList() const { ET_CHECK_MSG(isIntList(), "EValue is not an Int List."); + ET_CHECK_MSG( + payload.copyable_union.as_int_list_ptr != nullptr, + "EValue int list pointer is null."); return (payload.copyable_union.as_int_list_ptr)->get(); } /****** Bool List Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* b) : tag(Tag::ListBool) { + ET_CHECK_MSG(b != nullptr, "ArrayRef pointer cannot be null"); payload.copyable_union.as_bool_list_ptr = b; } @@ -320,12 +343,16 @@ struct EValue { executorch::aten::ArrayRef toBoolList() const { ET_CHECK_MSG(isBoolList(), "EValue is not a Bool List."); + ET_CHECK_MSG( + payload.copyable_union.as_bool_list_ptr != nullptr, + "EValue bool list pointer is null."); return *(payload.copyable_union.as_bool_list_ptr); } /****** Double List Type ******/ /*implicit*/ EValue(executorch::aten::ArrayRef* d) : tag(Tag::ListDouble) { + ET_CHECK_MSG(d != nullptr, "ArrayRef pointer cannot be null"); payload.copyable_union.as_double_list_ptr = d; } @@ -335,12 +362,17 @@ struct EValue { executorch::aten::ArrayRef toDoubleList() const { ET_CHECK_MSG(isDoubleList(), "EValue is not a Double List."); + ET_CHECK_MSG( + payload.copyable_union.as_double_list_ptr != nullptr, + "EValue double list pointer is null."); return *(payload.copyable_union.as_double_list_ptr); } /****** Tensor List Type ******/ /*implicit*/ EValue(BoxedEvalueList* t) : tag(Tag::ListTensor) { + ET_CHECK_MSG( + t != nullptr, "BoxedEvalueList pointer cannot be null"); payload.copyable_union.as_tensor_list_ptr = t; } @@ -350,6 +382,9 @@ struct EValue { executorch::aten::ArrayRef toTensorList() const { ET_CHECK_MSG(isTensorList(), "EValue is not a Tensor List."); + ET_CHECK_MSG( + payload.copyable_union.as_tensor_list_ptr != nullptr, + "EValue tensor list pointer is null."); return payload.copyable_union.as_tensor_list_ptr->get(); } @@ -357,6 +392,9 @@ struct EValue { /*implicit*/ EValue( BoxedEvalueList>* t) : tag(Tag::ListOptionalTensor) { + ET_CHECK_MSG( + t != nullptr, + "BoxedEvalueList> pointer cannot be null"); payload.copyable_union.as_list_optional_tensor_ptr = t; } @@ -366,6 +404,11 @@ struct EValue { executorch::aten::ArrayRef> toListOptionalTensor() const { + ET_CHECK_MSG( + isListOptionalTensor(), "EValue is not a List Optional Tensor."); + ET_CHECK_MSG( + payload.copyable_union.as_list_optional_tensor_ptr != nullptr, + "EValue list optional tensor pointer is null."); return payload.copyable_union.as_list_optional_tensor_ptr->get(); } @@ -445,12 +488,19 @@ struct EValue { // minor performance bump for a code maintainability hit if (isTensor()) { payload.as_tensor.~Tensor(); - } else if (isTensorList()) { - for (auto& tensor : toTensorList()) { + } else if ( + isTensorList() && + payload.copyable_union.as_tensor_list_ptr != nullptr) { + // for (auto& tensor : toTensorList()) { + for (auto& tensor : payload.copyable_union.as_tensor_list_ptr->get()) { tensor.~Tensor(); } - } else if (isListOptionalTensor()) { - for (auto& optional_tensor : toListOptionalTensor()) { + } else if ( + isListOptionalTensor() && + payload.copyable_union.as_list_optional_tensor_ptr != nullptr) { + // for (auto& optional_tensor : toListOptionalTensor()) { + for (auto& optional_tensor : + payload.copyable_union.as_list_optional_tensor_ptr->get()) { optional_tensor.~optional(); } } diff --git a/runtime/core/test/evalue_test.cpp b/runtime/core/test/evalue_test.cpp index f04745187bb..9e91ad70a0b 100644 --- a/runtime/core/test/evalue_test.cpp +++ b/runtime/core/test/evalue_test.cpp @@ -281,3 +281,130 @@ TEST_F(EValueTest, ConstructFromNullPtrAborts) { ET_EXPECT_DEATH({ EValue evalue(null_ptr); }, ""); } + +TEST_F(EValueTest, StringConstructorNullCheck) { + executorch::aten::ArrayRef* null_string_ptr = nullptr; + ET_EXPECT_DEATH({ EValue evalue(null_string_ptr); }, ""); +} + +TEST_F(EValueTest, BoolListConstructorNullCheck) { + executorch::aten::ArrayRef* null_bool_list_ptr = nullptr; + ET_EXPECT_DEATH({ EValue evalue(null_bool_list_ptr); }, ""); +} + +TEST_F(EValueTest, DoubleListConstructorNullCheck) { + executorch::aten::ArrayRef* null_double_list_ptr = nullptr; + ET_EXPECT_DEATH({ EValue evalue(null_double_list_ptr); }, ""); +} + +TEST_F(EValueTest, IntListConstructorNullCheck) { + BoxedEvalueList* null_int_list_ptr = nullptr; + ET_EXPECT_DEATH({ EValue evalue(null_int_list_ptr); }, ""); +} + +TEST_F(EValueTest, TensorListConstructorNullCheck) { + BoxedEvalueList* null_tensor_list_ptr = nullptr; + ET_EXPECT_DEATH({ EValue evalue(null_tensor_list_ptr); }, ""); +} + +TEST_F(EValueTest, OptionalTensorListConstructorNullCheck) { + BoxedEvalueList>* + null_optional_tensor_list_ptr = nullptr; + ET_EXPECT_DEATH({ EValue evalue(null_optional_tensor_list_ptr); }, ""); +} + +TEST_F(EValueTest, BoxedEvalueListConstructorNullChecks) { + std::array storage = {0, 0, 0}; + std::array values = { + EValue((int64_t)1), EValue((int64_t)2), EValue((int64_t)3)}; + std::array values_p = {&values[0], &values[1], &values[2]}; + + // Test null wrapped_vals + ET_EXPECT_DEATH( + { BoxedEvalueList list(nullptr, storage.data(), 3); }, ""); + + // Test null unwrapped_vals + ET_EXPECT_DEATH( + { BoxedEvalueList list(values_p.data(), nullptr, 3); }, ""); + + // Test negative size + ET_EXPECT_DEATH( + { BoxedEvalueList list(values_p.data(), storage.data(), -1); }, + ""); +} + +TEST_F(EValueTest, toListOptionalTensorTypeCheck) { + // Create an EValue that's not a ListOptionalTensor + EValue e((int64_t)42); + EXPECT_TRUE(e.isInt()); + EXPECT_FALSE(e.isListOptionalTensor()); + + // Should fail type check + ET_EXPECT_DEATH({ e.toListOptionalTensor(); }, ""); +} + +TEST_F(EValueTest, toStringNullPointerCheck) { + // Create an EValue with String tag but null pointer + EValue e; + e.tag = Tag::String; + e.payload.copyable_union.as_string_ptr = nullptr; + + // Should pass isString() check but fail null pointer check + EXPECT_TRUE(e.isString()); + ET_EXPECT_DEATH({ e.toString(); }, ""); +} + +TEST_F(EValueTest, toIntListNullPointerCheck) { + // Create an EValue with ListInt tag but null pointer + EValue e; + e.tag = Tag::ListInt; + e.payload.copyable_union.as_int_list_ptr = nullptr; + + // Should pass isIntList() check but fail null pointer check + EXPECT_TRUE(e.isIntList()); + ET_EXPECT_DEATH({ e.toIntList(); }, ""); +} + +TEST_F(EValueTest, toBoolListNullPointerCheck) { + // Create an EValue with ListBool tag but null pointer + EValue e; + e.tag = Tag::ListBool; + e.payload.copyable_union.as_bool_list_ptr = nullptr; + + // Should pass isBoolList() check but fail null pointer check + EXPECT_TRUE(e.isBoolList()); + ET_EXPECT_DEATH({ e.toBoolList(); }, ""); +} + +TEST_F(EValueTest, toDoubleListNullPointerCheck) { + // Create an EValue with ListDouble tag but null pointer + EValue e; + e.tag = Tag::ListDouble; + e.payload.copyable_union.as_double_list_ptr = nullptr; + + // Should pass isDoubleList() check but fail null pointer check + EXPECT_TRUE(e.isDoubleList()); + ET_EXPECT_DEATH({ e.toDoubleList(); }, ""); +} + +TEST_F(EValueTest, toTensorListNullPointerCheck) { + // Create an EValue with ListTensor tag but null pointer + EValue e; + e.tag = Tag::ListTensor; + e.payload.copyable_union.as_tensor_list_ptr = nullptr; + + // Should pass isTensorList() check but fail null pointer check + EXPECT_TRUE(e.isTensorList()); + ET_EXPECT_DEATH({ e.toTensorList(); }, ""); +} + +TEST_F(EValueTest, toListOptionalTensorNullPointerCheck) { + // Create an EValue with ListOptionalTensor tag but null pointer + EValue e; + e.tag = Tag::ListOptionalTensor; + e.payload.copyable_union.as_list_optional_tensor_ptr = nullptr; + + // Should pass isListOptionalTensor() check but fail null pointer check + EXPECT_TRUE(e.isListOptionalTensor()); + ET_EXPECT_DEATH({ e.toListOptionalTensor(); }, ""); +}