Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 30 additions & 7 deletions source/common/http/http1/balsa_parser.cc
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
#include "source/common/http/http1/balsa_parser.h"

#include <algorithm>
#include <array>
#include <cctype>
#include <cstddef>
#include <cstdint>

#include "source/common/common/assert.h"
Expand All @@ -25,21 +27,42 @@ constexpr absl::string_view kColonSlashSlash = "://";
constexpr char kResponseFirstByte = 'H';
constexpr absl::string_view kHttpVersionPrefix = "HTTP/";

// Allowed characters for field names according to Section 5.1
// and for methods according to Section 9.1 of RFC 9110:
// RFC 9110 Sections 5.1 and 9.1 define field names and methods as tokens:
// https://www.rfc-editor.org/rfc/rfc9110.html
constexpr absl::string_view kValidCharacters =
// Every byte accepted as a token character, stored as a NUL-terminated literal. The terminator is
// not a member of the set: the mask builder iterates over `sizeof(kValidCharacters) - 1` bytes.
constexpr char kValidCharacters[] =
"!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~";
constexpr absl::string_view::iterator kValidCharactersBegin = kValidCharacters.begin();
constexpr absl::string_view::iterator kValidCharactersEnd = kValidCharacters.end();

// Builds a 256-bit membership table (four 64-bit words) with one bit set for each byte listed in
// kValidCharacters. Byte value b is represented by bit (b % 64) of word (b / 64).
constexpr std::array<uint64_t, 4> makeValidCharacterMask() {
  std::array<uint64_t, 4> bits{};
  // The final array element is the literal's terminating NUL, which is not a token character.
  constexpr size_t kCharacterCount = sizeof(kValidCharacters) - 1;
  size_t pos = 0;
  while (pos < kCharacterCount) {
    const uint8_t code = static_cast<uint8_t>(kValidCharacters[pos]);
    bits[code >> 6] |= uint64_t{1} << (code & 63);
    ++pos;
  }
  return bits;
}

// Compile-time membership table for token characters: 4 x 64 bits, one bit per possible byte
// value, produced by makeValidCharacterMask().
// This keeps the per-character hot path branch-light and avoids a binary search through the valid
// character list for every byte in every HTTP/1 header name.
constexpr std::array<uint64_t, 4> kValidCharacterMask = makeValidCharacterMask();

// Returns true when `c` is one of the characters listed in kValidCharacters, determined by testing
// the corresponding bit of kValidCharacterMask.
constexpr bool isValidTokenCharacter(char c) {
  // Convert through uint8_t so that bytes >= 0x80 select words 2-3 of the table instead of
  // producing a negative index on platforms where char is signed.
  const uint8_t code = static_cast<uint8_t>(c);
  const uint64_t word = kValidCharacterMask[code >> 6];
  return ((word >> (code & 63)) & uint64_t{1}) != 0;
}

// Compile-time spot checks of the token table.
// Accepted: a letter of each case, a punctuation mark, and the first and last list entries.
static_assert(isValidTokenCharacter('a'));
static_assert(isValidTokenCharacter('Z'));
static_assert(isValidTokenCharacter('-'));
static_assert(isValidTokenCharacter('!'));
static_assert(isValidTokenCharacter('~'));
// Rejected: delimiters that must never appear inside a token.
static_assert(!isValidTokenCharacter(':'));
static_assert(!isValidTokenCharacter(' '));
// Rejected boundary bytes: NUL (the literal's terminator must not leak into the set), DEL, and
// high-bit bytes, which exercise the uint8_t conversion on platforms where char is signed.
static_assert(!isValidTokenCharacter('\0'));
static_assert(!isValidTokenCharacter('\x7f'));
static_assert(!isValidTokenCharacter(static_cast<char>(0x80)));
static_assert(!isValidTokenCharacter(static_cast<char>(0xff)));

// TODO(#21245): Skip method validation altogether when UHV method validation is
// enabled.
bool isMethodValid(absl::string_view method, bool allow_custom_methods) {
if (allow_custom_methods) {
return !method.empty() &&
std::all_of(method.begin(), method.end(), [](absl::string_view::value_type c) {
return std::binary_search(kValidCharactersBegin, kValidCharactersEnd, c);
return isValidTokenCharacter(c);
});
}

Expand Down Expand Up @@ -132,7 +155,7 @@ bool isVersionValid(absl::string_view version_input) {

// Returns true when every byte of `name` is a member of the valid token character set.
// std::all_of over an empty range yields true, so an empty name is reported as valid here; unlike
// isMethodValid there is no explicit emptiness check in this function.
bool isHeaderNameValid(absl::string_view name) {
return std::all_of(name.begin(), name.end(), [](absl::string_view::value_type c) {
return std::binary_search(kValidCharactersBegin, kValidCharactersEnd, c);
return isValidTokenCharacter(c);
});
}

Expand Down
26 changes: 26 additions & 0 deletions test/common/http/http1/BUILD
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
load(
"//bazel:envoy_build_system.bzl",
"envoy_benchmark_test",
"envoy_cc_benchmark_binary",
"envoy_cc_fuzz_test",
"envoy_cc_test",
"envoy_package",
Expand Down Expand Up @@ -46,6 +48,30 @@ envoy_cc_test(
],
)

# Unit tests for header-name and method token validation, built from balsa_parser_test.cc against
# the balsa parser library.
envoy_cc_test(
name = "balsa_parser_test",
srcs = ["balsa_parser_test.cc"],
rbe_pool = "6gig",
deps = [
"//source/common/http/http1:balsa_parser_lib",
],
)

# Benchmark binary for header parsing, built from balsa_parser_benchmark_test.cc; links the balsa
# parser library and the benchmark framework.
envoy_cc_benchmark_binary(
name = "balsa_parser_benchmark",
srcs = ["balsa_parser_benchmark_test.cc"],
rbe_pool = "6gig",
deps = [
"//source/common/http/http1:balsa_parser_lib",
"@benchmark",
],
)

# Test target wrapping the balsa_parser_benchmark binary declared above.
# NOTE(review): how the binary is invoked is defined by the envoy_benchmark_test macro in
# //bazel:envoy_build_system.bzl - confirm there if the run mode matters.
envoy_benchmark_test(
name = "balsa_parser_benchmark_test",
benchmark_binary = "balsa_parser_benchmark",
)

envoy_cc_test(
name = "conn_pool_test",
srcs = ["conn_pool_test.cc"],
Expand Down
126 changes: 126 additions & 0 deletions test/common/http/http1/balsa_parser_benchmark_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
#include "source/common/http/http1/balsa_parser.h"

#include <cstddef>
#include <cstdint>
#include <string>

#include "benchmark/benchmark.h"

namespace Envoy {
namespace Http {
namespace Http1 {
namespace {

// Character list appended to header names of shape FullTokenSetName (see makeHeaderName), so that
// shape feeds every listed character through header-name validation.
constexpr char kValidHttpTokenCharacters[] =
"!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~";

// Selects how synthetic header names are generated for the benchmark. The numeric values are the
// second benchmark argument ("shape"): bmParseHeaders casts state.range(1) back to this enum, so
// they must stay in sync with the {0, 1, 2} list in the BENCHMARK registration.
enum class HeaderNameShape : int {
// "x-benchmark-header-<index>-token-name".
LongTokenName = 0,
// "x<index>".
ShortName = 1,
// "x<index>-" followed by every character in kValidHttpTokenCharacters.
FullTokenSetName = 2,
};

// Maps a HeaderNameShape to the short label attached to the benchmark run. Values outside the
// enumerators (possible because the shape is cast from an integer argument) yield "unknown".
const char* headerNameShapeLabel(const HeaderNameShape shape) {
  if (shape == HeaderNameShape::LongTokenName) {
    return "long";
  }
  if (shape == HeaderNameShape::ShortName) {
    return "short";
  }
  if (shape == HeaderNameShape::FullTokenSetName) {
    return "full-token-set";
  }
  return "unknown";
}

// ParserCallbacks implementation for the benchmark: it accepts every event, counts header-field
// callbacks, and accumulates the lengths reported by the URL, status, header and body callbacks so
// the results can be handed to benchmark::DoNotOptimize.
class HeaderCountingCallbacks : public ParserCallbacks {
public:
  CallbackResult onMessageBegin() override { return CallbackResult::Success; }
  CallbackResult onUrl(const char*, size_t length) override { return consume(length); }
  CallbackResult onStatus(const char*, size_t length) override { return consume(length); }
  CallbackResult onHeaderField(const char*, size_t length) override {
    ++header_total_;
    return consume(length);
  }
  CallbackResult onHeaderValue(const char*, size_t length) override { return consume(length); }
  CallbackResult onHeadersComplete() override { return CallbackResult::Success; }
  void bufferBody(const char*, size_t length) override { byte_total_ += length; }
  CallbackResult onMessageComplete() override { return CallbackResult::Success; }
  void onChunkHeader(bool) override {}

  // Number of onHeaderField() invocations observed so far.
  size_t headersSeen() const { return header_total_; }
  // Sum of all `length` arguments observed so far.
  size_t bytesSeen() const { return byte_total_; }

private:
  // Adds `length` to the running byte total and reports success to the parser.
  CallbackResult consume(const size_t length) {
    byte_total_ += length;
    return CallbackResult::Success;
  }

  size_t byte_total_{0};
  size_t header_total_{0};
};

// Produces the `index`-th synthetic header name for the requested shape. The decimal index is
// embedded in every name so that names within one request are distinct. Shapes outside the
// enumerators yield "x-unknown".
std::string makeHeaderName(const int index, const HeaderNameShape shape) {
  const std::string ordinal = std::to_string(index);
  if (shape == HeaderNameShape::LongTokenName) {
    return "x-benchmark-header-" + ordinal + "-token-name";
  }
  if (shape == HeaderNameShape::ShortName) {
    return "x" + ordinal;
  }
  if (shape == HeaderNameShape::FullTokenSetName) {
    return "x" + ordinal + "-" + kValidHttpTokenCharacters;
  }
  return "x-unknown";
}

// Assembles a complete HTTP/1.1 GET request containing `header_count` headers, each named by
// makeHeaderName() and carrying the value "value", terminated by the empty line that ends the
// header section.
std::string makeRequest(const int header_count, const HeaderNameShape shape) {
  std::string out = "GET /benchmark HTTP/1.1\r\n";
  // Capacity estimate (32 bytes plus 96 per header line) to limit reallocation while appending.
  out.reserve(32 + static_cast<size_t>(header_count) * 96);
  int idx = 0;
  while (idx < header_count) {
    out += makeHeaderName(idx, shape);
    out += ": value\r\n";
    ++idx;
  }
  out += "\r\n";
  return out;
}

// Benchmark body: parses a synthetic request with state.range(0) headers whose names follow the
// HeaderNameShape given by state.range(1). A fresh parser and callback object are constructed for
// every iteration, so their construction is part of the measured work.
void bmParseHeaders(benchmark::State& state) {
  const int header_count = state.range(0);
  const auto shape = static_cast<HeaderNameShape>(state.range(1));
  const std::string request = makeRequest(header_count, shape);
  const char* const bytes = request.data();
  const size_t byte_count = request.size();
  state.SetLabel(headerNameShapeLabel(shape));

  // Untimed sanity pass: refuse to report numbers for input the parser does not fully accept.
  {
    HeaderCountingCallbacks probe;
    BalsaParser parser(MessageType::Request, &probe, byte_count, false, false);
    const bool consumed_all = parser.execute(bytes, byte_count) == byte_count;
    const bool parsed_cleanly = consumed_all && parser.getStatus() == ParserStatus::Ok &&
                                probe.headersSeen() == static_cast<size_t>(header_count);
    if (!parsed_cleanly) {
      state.SkipWithError("benchmark request failed to parse");
      return;
    }
  }

  for (auto _ : state) {
    HeaderCountingCallbacks sink;
    BalsaParser parser(MessageType::Request, &sink, byte_count, false, false);
    const size_t consumed = parser.execute(bytes, byte_count);
    // Keep the optimizer from discarding the parse as dead code.
    benchmark::DoNotOptimize(consumed);
    benchmark::DoNotOptimize(sink.headersSeen());
    benchmark::DoNotOptimize(sink.bytesSeen());
  }

  // Report throughput in headers per second and bytes per second.
  const int64_t iterations = state.iterations();
  state.SetItemsProcessed(iterations * header_count);
  state.SetBytesProcessed(iterations * static_cast<int64_t>(byte_count));
}

// Registers bmParseHeaders for every combination of header count {8, 16, 64, 256, 512} and
// HeaderNameShape value {0, 1, 2}; the second list must mirror the enum's explicit values.
BENCHMARK(bmParseHeaders)
->ArgsProduct({{8, 16, 64, 256, 512}, {0, 1, 2}})
->ArgNames({"headers", "shape"});

} // namespace
} // namespace Http1
} // namespace Http
} // namespace Envoy
106 changes: 106 additions & 0 deletions test/common/http/http1/balsa_parser_test.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
#include "source/common/http/http1/balsa_parser.h"

#include <array>
#include <string>
#include <vector>

#include "gtest/gtest.h"

namespace Envoy {
namespace Http {
namespace Http1 {
namespace {

// Character list the tests below use verbatim as a header name and as a custom method name, to
// check that each listed character is accepted by the parser.
constexpr char kValidHttpTokenCharacters[] =
"!#$%&'*+-.0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ^_`abcdefghijklmnopqrstuvwxyz|~";

// ParserCallbacks implementation for tests: it accepts every event and stores a copy of each
// header name passed to onHeaderField(), in arrival order.
class RecordingCallbacks : public ParserCallbacks {
public:
  // Header names recorded so far.
  const std::vector<std::string>& headerNames() const { return recorded_names_; }

  CallbackResult onMessageBegin() override { return CallbackResult::Success; }
  CallbackResult onUrl(const char*, size_t) override { return CallbackResult::Success; }
  CallbackResult onStatus(const char*, size_t) override { return CallbackResult::Success; }
  CallbackResult onHeaderField(const char* data, size_t length) override {
    // Copy the bytes: `data` is only borrowed from the parser for the duration of the call.
    recorded_names_.push_back(std::string(data, length));
    return CallbackResult::Success;
  }
  CallbackResult onHeaderValue(const char*, size_t) override { return CallbackResult::Success; }
  CallbackResult onHeadersComplete() override { return CallbackResult::Success; }
  void bufferBody(const char*, size_t) override {}
  CallbackResult onMessageComplete() override { return CallbackResult::Success; }
  void onChunkHeader(bool) override {}

private:
  std::vector<std::string> recorded_names_;
};

// A header whose name consists of every valid token character must parse successfully and be
// reported to the callbacks unchanged.
TEST(BalsaParserHeaderNameValidationTest, ParsesHeaderNameContainingEveryTokenCharacter) {
  const std::string request =
      std::string("GET / HTTP/1.1\r\n") + kValidHttpTokenCharacters + ": value\r\n\r\n";

  RecordingCallbacks callbacks;
  BalsaParser parser(MessageType::Request, &callbacks, request.size(), false, false);
  const size_t consumed = parser.execute(request.data(), request.size());

  EXPECT_EQ(request.size(), consumed);
  EXPECT_EQ(ParserStatus::Ok, parser.getStatus()) << parser.errorMessage();
  const std::vector<std::string>& names = callbacks.headerNames();
  // ASSERT so the element access below cannot run on an empty vector.
  ASSERT_EQ(1, names.size());
  EXPECT_EQ(kValidHttpTokenCharacters, names[0]);
}

// Each listed non-token byte, placed in the middle of a header name, must put the parser into the
// error state with a non-empty message and without any header name reaching the callbacks.
TEST(BalsaParserHeaderNameValidationTest, RejectsInvalidTokenCharactersInHeaderNames) {
  constexpr std::array<unsigned char, 21> invalid_characters = {
      0x00, ' ', '"', '(', ')', ',', '/', ';', '<', '=',
      '>', '?', '@', '[', '\\', ']', '{', '}', 0x7f, 0x80, 0xff};
  const std::string before_byte = "GET / HTTP/1.1\r\nbad";
  const std::string after_byte = "name: value\r\n\r\n";

  for (size_t i = 0; i < invalid_characters.size(); ++i) {
    const unsigned char byte = invalid_characters[i];
    // Identify the offending byte in any failure output.
    SCOPED_TRACE(testing::Message() << "character code " << static_cast<int>(byte));
    const std::string request = before_byte + static_cast<char>(byte) + after_byte;

    RecordingCallbacks callbacks;
    BalsaParser parser(MessageType::Request, &callbacks, request.size(), false, false);
    parser.execute(request.data(), request.size());

    EXPECT_EQ(ParserStatus::Error, parser.getStatus());
    EXPECT_FALSE(parser.errorMessage().empty());
    EXPECT_TRUE(callbacks.headerNames().empty());
  }
}

// A header name containing bytes with the high bit set must be rejected with the
// HPE_INVALID_HEADER_TOKEN error and must not be delivered to the callbacks.
TEST(BalsaParserHeaderNameValidationTest, RejectsHighBitHeaderNameCharacterAfterParsing) {
  // Header name is 'f', 0xc3, 0xb6, 'o' (the two middle bytes are the UTF-8 encoding of U+00F6).
  const std::string request = std::string("GET / HTTP/1.1\r\nf") + static_cast<char>(0xc3) +
                              static_cast<char>(0xb6) + "o: value\r\n\r\n";

  RecordingCallbacks callbacks;
  BalsaParser parser(MessageType::Request, &callbacks, request.size(), false, false);
  parser.execute(request.data(), request.size());

  EXPECT_EQ(ParserStatus::Error, parser.getStatus());
  EXPECT_EQ("HPE_INVALID_HEADER_TOKEN", parser.errorMessage());
  EXPECT_TRUE(callbacks.headerNames().empty());
}

// A request whose method consists of every valid token character must parse successfully and
// report that method verbatim.
TEST(BalsaParserMethodValidationTest, ParsesCustomMethodContainingEveryTokenCharacter) {
  const std::string request =
      std::string(kValidHttpTokenCharacters) + " / HTTP/1.1\r\nhost: example.com\r\n\r\n";

  RecordingCallbacks callbacks;
  // NOTE(review): the trailing `true` is presumably the allow-custom-methods switch; confirm
  // against the BalsaParser constructor declaration.
  BalsaParser parser(MessageType::Request, &callbacks, request.size(), false, true);
  const size_t consumed = parser.execute(request.data(), request.size());

  EXPECT_EQ(request.size(), consumed);
  EXPECT_EQ(ParserStatus::Ok, parser.getStatus()) << parser.errorMessage();
  EXPECT_EQ(kValidHttpTokenCharacters, parser.methodName());
}

} // namespace
} // namespace Http1
} // namespace Http
} // namespace Envoy