Skip to content

Commit

Permalink
Standardize on Init vs InitForOverwrite for value vs default initiali…
Browse files Browse the repository at this point in the history
…zation

C++ is fun and has two notions of "default" initialization, `new T` and
`new T()`. These are default initialization and value initialization,
respectively.

They are identical except that POD types are uninit when
default-initialized and zero when value-initialized. InplaceVector
picked the safer option by default and called the other one
FooMaybeUninit.  Array is older and uses the less safe one (it's almost
always the one we want; we usually allocate an array to immediately fill
it).

While MaybeUninit does capture what you do with it, it is slightly
ambiguous, as seen in Array's internal implementation: uninitialized
could also mean we haven't gotten around to initialize it at all. I.e.
we need to use a function like std::uninitialized_value_construct_n
instead of normal functions in <algorithm>.

C++20 has std::make_unique and std::make_unique_for_overwrite to capture
the two. This seems as fine a naming convention as any, so switch to it.
Along the way, make the internal bssl::Array default to the safer one.
This lets us remove a couple of memset(0)'s.

Change-Id: I32cede231da051a854e6251e10b87f8e4dd06ee6
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72268
Reviewed-by: Nick Harper <[email protected]>
Commit-Queue: David Benjamin <[email protected]>
  • Loading branch information
davidben authored and Boringssl LUCI CQ committed Oct 22, 2024
1 parent 4f76523 commit ce572d6
Show file tree
Hide file tree
Showing 16 changed files with 74 additions and 58 deletions.
2 changes: 1 addition & 1 deletion ssl/d1_both.cc
Original file line number Diff line number Diff line change
Expand Up @@ -785,7 +785,7 @@ static int send_flight(SSL *ssl) {
dtls1_update_mtu(ssl);

Array<uint8_t> packet;
if (!packet.Init(ssl->d1->mtu)) {
if (!packet.InitForOverwrite(ssl->d1->mtu)) {
return -1;
}

Expand Down
2 changes: 1 addition & 1 deletion ssl/encrypted_client_hello.cc
Original file line number Diff line number Diff line change
Expand Up @@ -303,7 +303,7 @@ bool ssl_client_hello_decrypt(SSL_HANDSHAKE *hs, uint8_t *out_alert,
return false;
}
#else
if (!encoded.Init(payload.size())) {
if (!encoded.InitForOverwrite(payload.size())) {
*out_alert = SSL_AD_INTERNAL_ERROR;
return false;
}
Expand Down
16 changes: 8 additions & 8 deletions ssl/extensions.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ static bool tls1_check_duplicate_extensions(const CBS *cbs) {
}

Array<uint16_t> extension_types;
if (!extension_types.Init(num_extensions)) {
if (!extension_types.InitForOverwrite(num_extensions)) {
return false;
}

Expand Down Expand Up @@ -2526,7 +2526,7 @@ static bool parse_u16_array(const CBS *cbs, Array<uint16_t> *out) {
}

Array<uint16_t> ret;
if (!ret.Init(CBS_len(&copy) / 2)) {
if (!ret.InitForOverwrite(CBS_len(&copy) / 2)) {
return false;
}
for (size_t i = 0; i < ret.size(); i++) {
Expand Down Expand Up @@ -2878,7 +2878,7 @@ static bool cert_compression_parse_clienthello(SSL_HANDSHAKE *hs,

const size_t num_given_alg_ids = CBS_len(&alg_ids) / 2;
Array<uint16_t> given_alg_ids;
if (!given_alg_ids.Init(num_given_alg_ids)) {
if (!given_alg_ids.InitForOverwrite(num_given_alg_ids)) {
return false;
}

Expand Down Expand Up @@ -3352,7 +3352,7 @@ bool ssl_setup_extension_permutation(SSL_HANDSHAKE *hs) {
uint32_t seeds[kNumExtensions - 1];
Array<uint8_t> permutation;
if (!RAND_bytes(reinterpret_cast<uint8_t *>(seeds), sizeof(seeds)) ||
!permutation.Init(kNumExtensions)) {
!permutation.InitForOverwrite(kNumExtensions)) {
return false;
}
for (size_t i = 0; i < kNumExtensions; i++) {
Expand Down Expand Up @@ -3918,7 +3918,7 @@ static enum ssl_ticket_aead_result_t decrypt_ticket_with_cipher_ctx(
if (ciphertext.size() >= INT_MAX) {
return ssl_ticket_aead_ignore_ticket;
}
if (!plaintext.Init(ciphertext.size())) {
if (!plaintext.InitForOverwrite(ciphertext.size())) {
return ssl_ticket_aead_error;
}
int len1, len2;
Expand Down Expand Up @@ -4006,7 +4006,7 @@ static enum ssl_ticket_aead_result_t ssl_decrypt_ticket_with_method(
SSL_HANDSHAKE *hs, Array<uint8_t> *out, bool *out_renew_ticket,
Span<const uint8_t> ticket) {
Array<uint8_t> plaintext;
if (!plaintext.Init(ticket.size())) {
if (!plaintext.InitForOverwrite(ticket.size())) {
return ssl_ticket_aead_error;
}

Expand Down Expand Up @@ -4115,7 +4115,7 @@ enum ssl_ticket_aead_result_t ssl_process_ticket(
// Envoy's tests expect the session to have a session ID that matches the
// placeholder used by the client. It's unclear whether this is a good idea,
// but we maintain it for now.
session->session_id.ResizeMaybeUninit(SHA256_DIGEST_LENGTH);
session->session_id.ResizeForOverwrite(SHA256_DIGEST_LENGTH);
SHA256(ticket.data(), ticket.size(), session->session_id.data());

*out_session = std::move(session);
Expand Down Expand Up @@ -4356,7 +4356,7 @@ bool tls1_record_handshake_hashes_for_channel_id(SSL_HANDSHAKE *hs) {
}

size_t digest_len;
hs->new_session->original_handshake_hash.ResizeMaybeUninit(
hs->new_session->original_handshake_hash.ResizeForOverwrite(
hs->transcript.DigestLen());
if (!hs->transcript.GetHash(hs->new_session->original_handshake_hash.data(),
&digest_len)) {
Expand Down
4 changes: 2 additions & 2 deletions ssl/handoff.cc
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ static bool apply_remote_features(SSL *ssl, CBS *in) {
return false;
}
Array<uint16_t> supported_groups;
if (!supported_groups.Init(CBS_len(&groups) / 2)) {
if (!supported_groups.InitForOverwrite(CBS_len(&groups) / 2)) {
return false;
}
size_t idx = 0;
Expand All @@ -190,7 +190,7 @@ static bool apply_remote_features(SSL *ssl, CBS *in) {
Span<const uint16_t> configured_groups =
tls1_get_grouplist(ssl->s3->hs.get());
Array<uint16_t> new_configured_groups;
if (!new_configured_groups.Init(configured_groups.size())) {
if (!new_configured_groups.InitForOverwrite(configured_groups.size())) {
return false;
}
idx = 0;
Expand Down
14 changes: 6 additions & 8 deletions ssl/handshake_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ static enum ssl_hs_wait_t do_start_connect(SSL_HANDSHAKE *hs) {
if (has_id_session) {
hs->session_id = ssl->session->session_id;
} else if (ticket_session_requires_random_id || enable_compatibility_mode) {
hs->session_id.ResizeMaybeUninit(SSL_MAX_SSL_SESSION_ID_LENGTH);
hs->session_id.ResizeForOverwrite(SSL_MAX_SSL_SESSION_ID_LENGTH);
if (!RAND_bytes(hs->session_id.data(), hs->session_id.size())) {
return ssl_hs_error;
}
Expand Down Expand Up @@ -1528,16 +1528,15 @@ static enum ssl_hs_wait_t do_send_client_key_exchange(SSL_HANDSHAKE *hs) {

// Depending on the key exchange method, compute |pms|.
if (alg_k & SSL_kRSA) {
if (!pms.Init(SSL_MAX_MASTER_KEY_LENGTH)) {
return ssl_hs_error;
}

RSA *rsa = EVP_PKEY_get0_RSA(hs->peer_pubkey.get());
if (rsa == NULL) {
OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
return ssl_hs_error;
}

if (!pms.InitForOverwrite(SSL_MAX_MASTER_KEY_LENGTH)) {
return ssl_hs_error;
}
pms[0] = hs->client_version >> 8;
pms[1] = hs->client_version & 0xff;
if (!RAND_bytes(&pms[2], SSL_MAX_MASTER_KEY_LENGTH - 2)) {
Expand Down Expand Up @@ -1581,7 +1580,6 @@ static enum ssl_hs_wait_t do_send_client_key_exchange(SSL_HANDSHAKE *hs) {
if (!pms.Init(psk_len)) {
return ssl_hs_error;
}
OPENSSL_memset(pms.data(), 0, pms.size());
} else {
ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
Expand Down Expand Up @@ -1609,7 +1607,7 @@ static enum ssl_hs_wait_t do_send_client_key_exchange(SSL_HANDSHAKE *hs) {
return ssl_hs_error;
}

hs->new_session->secret.ResizeMaybeUninit(SSL3_MASTER_SECRET_SIZE);
hs->new_session->secret.ResizeForOverwrite(SSL3_MASTER_SECRET_SIZE);
if (!tls1_generate_master_secret(hs, MakeSpan(hs->new_session->secret),
pms)) {
return ssl_hs_error;
Expand Down Expand Up @@ -1850,7 +1848,7 @@ static enum ssl_hs_wait_t do_read_session_ticket(SSL_HANDSHAKE *hs) {

// Historically, OpenSSL filled in fake session IDs for ticket-based sessions.
// TODO(davidben): Are external callers relying on this? Try removing this.
hs->new_session->session_id.ResizeMaybeUninit(SHA256_DIGEST_LENGTH);
hs->new_session->session_id.ResizeForOverwrite(SHA256_DIGEST_LENGTH);
SHA256(CBS_data(&ticket), CBS_len(&ticket),
hs->new_session->session_id.data());

Expand Down
10 changes: 5 additions & 5 deletions ssl/handshake_server.cc
Original file line number Diff line number Diff line change
Expand Up @@ -941,7 +941,7 @@ static enum ssl_hs_wait_t do_select_parameters(SSL_HANDSHAKE *hs) {
// Assign a session ID if not using session tickets.
if (!hs->ticket_expected &&
(ssl->ctx->session_cache_mode & SSL_SESS_CACHE_SERVER)) {
hs->new_session->session_id.ResizeMaybeUninit(SSL3_SSL_SESSION_ID_LENGTH);
hs->new_session->session_id.ResizeForOverwrite(SSL3_SSL_SESSION_ID_LENGTH);
RAND_bytes(hs->new_session->session_id.data(),
hs->new_session->session_id.size());
}
Expand Down Expand Up @@ -1464,7 +1464,8 @@ static enum ssl_hs_wait_t do_read_client_key_exchange(SSL_HANDSHAKE *hs) {

// Allocate a buffer large enough for an RSA decryption.
Array<uint8_t> decrypt_buf;
if (!decrypt_buf.Init(EVP_PKEY_size(hs->credential->pubkey.get()))) {
if (!decrypt_buf.InitForOverwrite(
EVP_PKEY_size(hs->credential->pubkey.get()))) {
return ssl_hs_error;
}

Expand Down Expand Up @@ -1492,7 +1493,7 @@ static enum ssl_hs_wait_t do_read_client_key_exchange(SSL_HANDSHAKE *hs) {

// Prepare a random premaster, to be used on invalid padding. See RFC 5246,
// section 7.4.7.1.
if (!premaster_secret.Init(SSL_MAX_MASTER_KEY_LENGTH) ||
if (!premaster_secret.InitForOverwrite(SSL_MAX_MASTER_KEY_LENGTH) ||
!RAND_bytes(premaster_secret.data(), premaster_secret.size())) {
return ssl_hs_error;
}
Expand Down Expand Up @@ -1583,7 +1584,6 @@ static enum ssl_hs_wait_t do_read_client_key_exchange(SSL_HANDSHAKE *hs) {
if (!premaster_secret.Init(psk_len)) {
return ssl_hs_error;
}
OPENSSL_memset(premaster_secret.data(), 0, premaster_secret.size());
}

ScopedCBB new_premaster;
Expand All @@ -1605,7 +1605,7 @@ static enum ssl_hs_wait_t do_read_client_key_exchange(SSL_HANDSHAKE *hs) {
}

// Compute the master secret.
hs->new_session->secret.ResizeMaybeUninit(SSL3_MASTER_SECRET_SIZE);
hs->new_session->secret.ResizeForOverwrite(SSL3_MASTER_SECRET_SIZE);
if (!tls1_generate_master_secret(hs, MakeSpan(hs->new_session->secret),
premaster_secret)) {
return ssl_hs_error;
Expand Down
30 changes: 20 additions & 10 deletions ssl/internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -334,11 +334,21 @@ class Array {
}

// Init replaces the array with a newly-allocated array of |new_size|
// default-constructed copies of |T|. It returns true on success and false on
// error.
//
// Note that if |T| is a primitive type like |uint8_t|, it is uninitialized.
// value-constructed copies of |T|. It returns true on success and false on
// error. If |T| is a primitive type like |uint8_t|, value-construction means
// it will be zero-initialized.
bool Init(size_t new_size) {
if (!InitUninitialized(new_size)) {
return false;
}
cxx17_uninitialized_value_construct_n(data_, size_);
return true;
}

// InitForOverwrite behaves like |Init| but it default-constructs each element
// instead. This means that, if |T| is a primitive type, the array will be
// uninitialized and thus must be filled in by the caller.
bool InitForOverwrite(size_t new_size) {
if (!InitUninitialized(new_size)) {
return false;
}
Expand Down Expand Up @@ -585,10 +595,10 @@ class InplaceVector {
return true;
}

// TryResizeMaybeUninit behaves like |TryResize|, but newly-added elements are
// default-initialized, so POD types may contain uninitialized values that the
// caller is responsible for filling in.
bool TryResizeMaybeUninit(size_t new_size) {
// TryResizeForOverwrite behaves like |TryResize|, but newly-added elements
// are default-initialized, so POD types may contain uninitialized values that
// the caller is responsible for filling in.
bool TryResizeForOverwrite(size_t new_size) {
if (new_size <= size_) {
Shrink(new_size);
return true;
Expand Down Expand Up @@ -628,8 +638,8 @@ class InplaceVector {
// The following methods behave like their |Try*| counterparts, but abort the
// program on failure.
void Resize(size_t size) { BSSL_CHECK(TryResize(size)); }
void ResizeMaybeUninit(size_t size) {
BSSL_CHECK(TryResizeMaybeUninit(size));
void ResizeForOverwrite(size_t size) {
BSSL_CHECK(TryResizeForOverwrite(size));
}
void CopyFrom(Span<const T> in) { BSSL_CHECK(TryCopyFrom(in)); }
T &PushBack(T val) {
Expand Down
9 changes: 3 additions & 6 deletions ssl/ssl_cipher.cc
Original file line number Diff line number Diff line change
Expand Up @@ -919,7 +919,6 @@ static bool ssl_cipher_strength_sort(CIPHER_ORDER **head_p,
if (!number_uses.Init(max_strength_bits + 1)) {
return false;
}
OPENSSL_memset(number_uses.data(), 0, (max_strength_bits + 1) * sizeof(int));

// Now find the strength_bits values actually used.
curr = *head_p;
Expand Down Expand Up @@ -1231,7 +1230,7 @@ bool ssl_create_cipher_list(UniquePtr<SSLCipherPreferenceList> *out_cipher_list,
UniquePtr<STACK_OF(SSL_CIPHER)> cipherstack(sk_SSL_CIPHER_new_null());
Array<bool> in_group_flags;
if (cipherstack == nullptr ||
!in_group_flags.Init(OPENSSL_ARRAY_SIZE(kCiphers))) {
!in_group_flags.InitForOverwrite(OPENSSL_ARRAY_SIZE(kCiphers))) {
return false;
}

Expand All @@ -1246,13 +1245,11 @@ bool ssl_create_cipher_list(UniquePtr<SSLCipherPreferenceList> *out_cipher_list,
in_group_flags[num_in_group_flags++] = curr->in_group;
}
}
in_group_flags.Shrink(num_in_group_flags);

UniquePtr<SSLCipherPreferenceList> pref_list =
MakeUnique<SSLCipherPreferenceList>();
if (!pref_list ||
!pref_list->Init(
std::move(cipherstack),
MakeConstSpan(in_group_flags).subspan(0, num_in_group_flags))) {
if (!pref_list || !pref_list->Init(std::move(cipherstack), in_group_flags)) {
return false;
}

Expand Down
2 changes: 1 addition & 1 deletion ssl/ssl_credential.cc
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ bool ssl_get_credential_list(SSL_HANDSHAKE *hs, Array<SSL_CREDENTIAL *> *out) {
num_creds++;
}

if (!out->Init(num_creds)) {
if (!out->InitForOverwrite(num_creds)) {
return false;
}

Expand Down
11 changes: 10 additions & 1 deletion ssl/ssl_internal_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,15 @@
BSSL_NAMESPACE_BEGIN
namespace {

TEST(ArrayTest, InitValueConstructs) {
Array<uint8_t> array;
ASSERT_TRUE(array.Init(10));
EXPECT_EQ(array.size(), 10u);
for (size_t i = 0; i < 10u; i++) {
EXPECT_EQ(0u, array[i]);
}
}

TEST(ArrayDeathTest, BoundsChecks) {
Array<int> array;
const int v[] = {1, 2, 3, 4};
Expand Down Expand Up @@ -345,7 +354,7 @@ TEST(InplaceVectorDeathTest, BoundsChecks) {
EXPECT_DEATH_IF_SUPPORTED(vec[1000], "");
// The vector cannot be resized past the capacity.
EXPECT_DEATH_IF_SUPPORTED(vec.Resize(5), "");
EXPECT_DEATH_IF_SUPPORTED(vec.ResizeMaybeUninit(5), "");
EXPECT_DEATH_IF_SUPPORTED(vec.ResizeForOverwrite(5), "");
int too_much_data[] = {1, 2, 3, 4, 5};
EXPECT_DEATH_IF_SUPPORTED(vec.CopyFrom(too_much_data), "");
vec.Resize(4);
Expand Down
14 changes: 8 additions & 6 deletions ssl/ssl_key_share.cc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class ECKeyShare : public SSLKeyShare {

// Encode the x-coordinate left-padded with zeros.
Array<uint8_t> secret;
if (!secret.Init((EC_GROUP_get_degree(group_) + 7) / 8) ||
if (!secret.InitForOverwrite((EC_GROUP_get_degree(group_) + 7) / 8) ||
!BN_bn2bin_padded(secret.data(), secret.size(), x.get())) {
return false;
}
Expand Down Expand Up @@ -162,7 +162,7 @@ class X25519KeyShare : public SSLKeyShare {
*out_alert = SSL_AD_INTERNAL_ERROR;

Array<uint8_t> secret;
if (!secret.Init(32)) {
if (!secret.InitForOverwrite(32)) {
return false;
}

Expand Down Expand Up @@ -220,7 +220,7 @@ class X25519Kyber768KeyShare : public SSLKeyShare {
bool Encap(CBB *out_ciphertext, Array<uint8_t> *out_secret,
uint8_t *out_alert, Span<const uint8_t> peer_key) override {
Array<uint8_t> secret;
if (!secret.Init(32 + KYBER_SHARED_SECRET_BYTES)) {
if (!secret.InitForOverwrite(32 + KYBER_SHARED_SECRET_BYTES)) {
return false;
}

Expand Down Expand Up @@ -260,7 +260,7 @@ class X25519Kyber768KeyShare : public SSLKeyShare {
*out_alert = SSL_AD_INTERNAL_ERROR;

Array<uint8_t> secret;
if (!secret.Init(32 + KYBER_SHARED_SECRET_BYTES)) {
if (!secret.InitForOverwrite(32 + KYBER_SHARED_SECRET_BYTES)) {
return false;
}

Expand Down Expand Up @@ -308,7 +308,8 @@ class X25519MLKEM768KeyShare : public SSLKeyShare {
bool Encap(CBB *out_ciphertext, Array<uint8_t> *out_secret,
uint8_t *out_alert, Span<const uint8_t> peer_key) override {
Array<uint8_t> secret;
if (!secret.Init(MLKEM_SHARED_SECRET_BYTES + X25519_SHARED_KEY_LEN)) {
if (!secret.InitForOverwrite(MLKEM_SHARED_SECRET_BYTES +
X25519_SHARED_KEY_LEN)) {
return false;
}

Expand Down Expand Up @@ -349,7 +350,8 @@ class X25519MLKEM768KeyShare : public SSLKeyShare {
*out_alert = SSL_AD_INTERNAL_ERROR;

Array<uint8_t> secret;
if (!secret.Init(MLKEM_SHARED_SECRET_BYTES + X25519_SHARED_KEY_LEN)) {
if (!secret.InitForOverwrite(MLKEM_SHARED_SECRET_BYTES +
X25519_SHARED_KEY_LEN)) {
return false;
}

Expand Down
Loading

0 comments on commit ce572d6

Please sign in to comment.