diff --git a/source/common/http/http1/balsa_parser.cc b/source/common/http/http1/balsa_parser.cc index 583db79fbfd44..84150fb7f6d2b 100644 --- a/source/common/http/http1/balsa_parser.cc +++ b/source/common/http/http1/balsa_parser.cc @@ -1,7 +1,9 @@ #include "source/common/http/http1/balsa_parser.h" #include +#include #include +#include #include #include "source/common/common/assert.h" @@ -25,13 +27,34 @@ constexpr absl::string_view kColonSlashSlash = "://"; constexpr char kResponseFirstByte = 'H'; constexpr absl::string_view kHttpVersionPrefix = "HTTP/"; -// Allowed characters for field names according to Section 5.1 -// and for methods according to Section 9.1 of RFC 9110: +// RFC 9110 Sections 5.1 and 9.1 define field names and methods as tokens: // https://www.rfc-editor.org/rfc/rfc9110.html -constexpr absl::string_view kValidCharacters = +constexpr char kValidCharacters[] = "!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~"; -constexpr absl::string_view::iterator kValidCharactersBegin = kValidCharacters.begin(); -constexpr absl::string_view::iterator kValidCharactersEnd = kValidCharacters.end(); + +constexpr std::array makeValidCharacterMask() { + std::array mask{}; + for (size_t i = 0; i < sizeof(kValidCharacters) - 1; ++i) { + const uint8_t index = static_cast(kValidCharacters[i]); + mask[index / 64] |= 1ULL << (index % 64); + } + return mask; +} + +// This keeps the per-character hot path branch-light and avoids a binary search through the valid +// character list for every byte in every HTTP/1 header name. +constexpr std::array kValidCharacterMask = makeValidCharacterMask(); + +constexpr bool isValidTokenCharacter(char c) { + const uint8_t index = static_cast(c); + return (kValidCharacterMask[index / 64] & (1ULL << (index % 64))) != 0; +} + +static_assert(isValidTokenCharacter('a')); +static_assert(isValidTokenCharacter('Z')); +static_assert(isValidTokenCharacter('-')); +static_assert(!isValidTokenCharacter(':')); +static_assert(!isValidTokenCharacter(' ')); // TODO(#21245): Skip method validation altogether when UHV method validation is // enabled. @@ -39,7 +62,7 @@ bool isMethodValid(absl::string_view method, bool allow_custom_methods) { if (allow_custom_methods) { return !method.empty() && std::all_of(method.begin(), method.end(), [](absl::string_view::value_type c) { - return std::binary_search(kValidCharactersBegin, kValidCharactersEnd, c); + return isValidTokenCharacter(c); }); } @@ -132,7 +155,7 @@ bool isVersionValid(absl::string_view version_input) { bool isHeaderNameValid(absl::string_view name) { return std::all_of(name.begin(), name.end(), [](absl::string_view::value_type c) { - return std::binary_search(kValidCharactersBegin, kValidCharactersEnd, c); + return isValidTokenCharacter(c); }); } diff --git a/test/common/http/http1/BUILD b/test/common/http/http1/BUILD index 017acc2874cc6..21e417fda1913 100644 --- a/test/common/http/http1/BUILD +++ b/test/common/http/http1/BUILD @@ -1,5 +1,7 @@ load( "//bazel:envoy_build_system.bzl", + "envoy_benchmark_test", + "envoy_cc_benchmark_binary", "envoy_cc_fuzz_test", "envoy_cc_test", "envoy_package", @@ -46,6 +48,30 @@ envoy_cc_test( ], ) +envoy_cc_test( + name = "balsa_parser_test", + srcs = ["balsa_parser_test.cc"], + rbe_pool = "6gig", + deps = [ + "//source/common/http/http1:balsa_parser_lib", + ], +) + +envoy_cc_benchmark_binary( + name = "balsa_parser_benchmark", + srcs = ["balsa_parser_benchmark_test.cc"], + rbe_pool = "6gig", + deps = [ + "//source/common/http/http1:balsa_parser_lib", + "@benchmark", + ], +) + +envoy_benchmark_test( + name = "balsa_parser_benchmark_test", + benchmark_binary = "balsa_parser_benchmark", +) + envoy_cc_test( name = "conn_pool_test", srcs = ["conn_pool_test.cc"], diff --git a/test/common/http/http1/balsa_parser_benchmark_test.cc b/test/common/http/http1/balsa_parser_benchmark_test.cc new file mode 100644 index 0000000000000..15830f9485e74 --- /dev/null +++ b/test/common/http/http1/balsa_parser_benchmark_test.cc @@ -0,0 +1,126 @@ +#include "source/common/http/http1/balsa_parser.h" + +#include +#include +#include + +#include "benchmark/benchmark.h" + +namespace Envoy { +namespace Http { +namespace Http1 { +namespace { + +constexpr char kValidHttpTokenCharacters[] = + "!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~"; + +enum class HeaderNameShape : int { + LongTokenName = 0, + ShortName = 1, + FullTokenSetName = 2, +}; + +const char* headerNameShapeLabel(const HeaderNameShape shape) { + switch (shape) { + case HeaderNameShape::LongTokenName: + return "long"; + case HeaderNameShape::ShortName: + return "short"; + case HeaderNameShape::FullTokenSetName: + return "full-token-set"; + } + return "unknown"; +} + +class HeaderCountingCallbacks : public ParserCallbacks { +public: + CallbackResult onMessageBegin() override { return CallbackResult::Success; } + CallbackResult onUrl(const char*, size_t length) override { + bytes_seen_ += length; + return CallbackResult::Success; + } + CallbackResult onStatus(const char*, size_t length) override { + bytes_seen_ += length; + return CallbackResult::Success; + } + CallbackResult onHeaderField(const char*, size_t length) override { + bytes_seen_ += length; + ++headers_seen_; + return CallbackResult::Success; + } + CallbackResult onHeaderValue(const char*, size_t length) override { + bytes_seen_ += length; + return CallbackResult::Success; + } + CallbackResult onHeadersComplete() override { return CallbackResult::Success; } + void bufferBody(const char*, size_t length) override { bytes_seen_ += length; } + CallbackResult onMessageComplete() override { return CallbackResult::Success; } + void onChunkHeader(bool) override {} + + size_t headersSeen() const { return headers_seen_; } + size_t bytesSeen() const { return bytes_seen_; } + +private: + size_t headers_seen_{}; + size_t bytes_seen_{}; +}; + +std::string makeHeaderName(const int index, const HeaderNameShape shape) { + switch (shape) { + case HeaderNameShape::LongTokenName: + return "x-benchmark-header-" + std::to_string(index) + "-token-name"; + case HeaderNameShape::ShortName: + return "x" + std::to_string(index); + case HeaderNameShape::FullTokenSetName: + return "x" + std::to_string(index) + "-" + kValidHttpTokenCharacters; + } + return "x-unknown"; +} + +std::string makeRequest(const int header_count, const HeaderNameShape shape) { + std::string request = "GET /benchmark HTTP/1.1\r\n"; + request.reserve(32 + static_cast(header_count) * 96); + for (int i = 0; i < header_count; ++i) { + request.append(makeHeaderName(i, shape)); + request.append(": value\r\n"); + } + request.append("\r\n"); + return request; +} + +void bmParseHeaders(benchmark::State& state) { + const int header_count = state.range(0); + const HeaderNameShape shape = static_cast(state.range(1)); + const std::string request = makeRequest(header_count, shape); + state.SetLabel(headerNameShapeLabel(shape)); + { + HeaderCountingCallbacks callbacks; + BalsaParser parser(MessageType::Request, &callbacks, request.size(), false, false); + const size_t parsed = parser.execute(request.data(), request.size()); + if (parsed != request.size() || parser.getStatus() != ParserStatus::Ok || + callbacks.headersSeen() != static_cast(header_count)) { + state.SkipWithError("benchmark request failed to parse"); + return; + } + } + + for (auto _ : state) { + HeaderCountingCallbacks callbacks; + BalsaParser parser(MessageType::Request, &callbacks, request.size(), false, false); + const size_t parsed = parser.execute(request.data(), request.size()); + benchmark::DoNotOptimize(parsed); + benchmark::DoNotOptimize(callbacks.headersSeen()); + benchmark::DoNotOptimize(callbacks.bytesSeen()); + } + state.SetItemsProcessed(state.iterations() * header_count); + state.SetBytesProcessed(state.iterations() * static_cast(request.size())); +} + +BENCHMARK(bmParseHeaders) + ->ArgsProduct({{8, 16, 64, 256, 512}, {0, 1, 2}}) + ->ArgNames({"headers", "shape"}); + +} // namespace +} // namespace Http1 +} // namespace Http +} // namespace Envoy diff --git a/test/common/http/http1/balsa_parser_test.cc b/test/common/http/http1/balsa_parser_test.cc new file mode 100644 index 0000000000000..1e1a45658e522 --- /dev/null +++ b/test/common/http/http1/balsa_parser_test.cc @@ -0,0 +1,106 @@ +#include "source/common/http/http1/balsa_parser.h" + +#include +#include +#include + +#include "gtest/gtest.h" + +namespace Envoy { +namespace Http { +namespace Http1 { +namespace { + +constexpr char kValidHttpTokenCharacters[] = + "!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~"; + +class RecordingCallbacks : public ParserCallbacks { +public: + CallbackResult onMessageBegin() override { return CallbackResult::Success; } + CallbackResult onUrl(const char*, size_t) override { return CallbackResult::Success; } + CallbackResult onStatus(const char*, size_t) override { return CallbackResult::Success; } + CallbackResult onHeaderField(const char* data, size_t length) override { + header_names_.emplace_back(data, length); + return CallbackResult::Success; + } + CallbackResult onHeaderValue(const char*, size_t) override { return CallbackResult::Success; } + CallbackResult onHeadersComplete() override { return CallbackResult::Success; } + void bufferBody(const char*, size_t) override {} + CallbackResult onMessageComplete() override { return CallbackResult::Success; } + void onChunkHeader(bool) override {} + + const std::vector& headerNames() const { return header_names_; } + +private: + std::vector header_names_; +}; + +TEST(BalsaParserHeaderNameValidationTest, ParsesHeaderNameContainingEveryTokenCharacter) { + std::string request = "GET / HTTP/1.1\r\n"; + request += kValidHttpTokenCharacters; + request += ": value\r\n\r\n"; + + RecordingCallbacks callbacks; + BalsaParser parser(MessageType::Request, &callbacks, request.size(), false, false); + + EXPECT_EQ(request.size(), parser.execute(request.data(), request.size())); + EXPECT_EQ(ParserStatus::Ok, parser.getStatus()) << parser.errorMessage(); + ASSERT_EQ(1, callbacks.headerNames().size()); + EXPECT_EQ(kValidHttpTokenCharacters, callbacks.headerNames()[0]); +} + +TEST(BalsaParserHeaderNameValidationTest, RejectsInvalidTokenCharactersInHeaderNames) { + constexpr std::array invalid_characters = { + 0x00, ' ', '"', '(', ')', ',', '/', ';', '<', '=', + '>', '?', '@', '[', '\\', ']', '{', '}', 0x7f, 0x80, 0xff}; + + for (const unsigned char c : invalid_characters) { + SCOPED_TRACE(testing::Message() << "character code " << static_cast(c)); + std::string request = "GET / HTTP/1.1\r\nbad"; + request.push_back(static_cast(c)); + request += "name: value\r\n\r\n"; + + RecordingCallbacks callbacks; + BalsaParser parser(MessageType::Request, &callbacks, request.size(), false, false); + + parser.execute(request.data(), request.size()); + + EXPECT_EQ(ParserStatus::Error, parser.getStatus()); + EXPECT_FALSE(parser.errorMessage().empty()); + EXPECT_TRUE(callbacks.headerNames().empty()); + } +} + +TEST(BalsaParserHeaderNameValidationTest, RejectsHighBitHeaderNameCharacterAfterParsing) { + std::string request = "GET / HTTP/1.1\r\nf"; + request.push_back(static_cast(0xc3)); + request.push_back(static_cast(0xb6)); + request += "o: value\r\n\r\n"; + + RecordingCallbacks callbacks; + BalsaParser parser(MessageType::Request, &callbacks, request.size(), false, false); + + parser.execute(request.data(), request.size()); + + EXPECT_EQ(ParserStatus::Error, parser.getStatus()); + EXPECT_EQ("HPE_INVALID_HEADER_TOKEN", parser.errorMessage()); + EXPECT_TRUE(callbacks.headerNames().empty()); +} + +TEST(BalsaParserMethodValidationTest, ParsesCustomMethodContainingEveryTokenCharacter) { + std::string request; + request += kValidHttpTokenCharacters; + request += " / HTTP/1.1\r\nhost: example.com\r\n\r\n"; + + RecordingCallbacks callbacks; + BalsaParser parser(MessageType::Request, &callbacks, request.size(), false, true); + + EXPECT_EQ(request.size(), parser.execute(request.data(), request.size())); + EXPECT_EQ(ParserStatus::Ok, parser.getStatus()) << parser.errorMessage(); + EXPECT_EQ(kValidHttpTokenCharacters, parser.methodName()); +} + +} // namespace +} // namespace Http1 +} // namespace Http +} // namespace Envoy