diff --git a/ns2dohd/ns2dohd.c b/ns2dohd/ns2dohd.c index 59a061c..dbf9a69 100644 --- a/ns2dohd/ns2dohd.c +++ b/ns2dohd/ns2dohd.c @@ -1346,7 +1346,11 @@ int main(int argc, char *argv[]) goto odoh_fail; } if (odoh_client_decrypt_response(&odoh_client, odoh_rep, (uint16_t)reply_len, - dns_rep, &dns_out_len) != 0) { + dns_rep, sizeof(dns_rep), &dns_out_len) != 0) { + stats.errors++; + goto odoh_fail; + } + if (dns_out_len > sizeof(dns_rep)) { stats.errors++; goto odoh_fail; } diff --git a/proxy/dohproxyd.c b/proxy/dohproxyd.c index 1e85930..1fc3e7b 100644 --- a/proxy/dohproxyd.c +++ b/proxy/dohproxyd.c @@ -85,6 +85,7 @@ enum req_type { struct target_conn { struct upstream up; int fd; + int require_public_target; WOLFSSL *ssl; nghttp2_session *session; struct forward_ctx *active_fx; @@ -100,6 +101,95 @@ struct client { struct client *next; }; +static int target_host_has_forbidden_syntax(const char *host) +{ + struct in_addr a4; + struct in6_addr a6; + size_t i; + + if (!host || *host == '\0') + return 1; + if (strchr(host, '@') || strchr(host, '/') || strchr(host, '\\') || + strchr(host, '?') || strchr(host, '#') || strchr(host, '%')) + return 1; + for (i = 0; host[i] != '\0'; i++) { + if (isspace((unsigned char)host[i])) + return 1; + } + if (strchr(host, ':') != NULL) + return inet_pton(AF_INET6, host, &a6) != 1; + if (inet_pton(AF_INET, host, &a4) == 1) + return 0; + for (i = 0; host[i] != '\0'; i++) { + unsigned char ch = (unsigned char)host[i]; + if (!(isalnum(ch) || ch == '.' || ch == '-')) + return 1; + } + return 0; +} + +static int sockaddr_is_public(const struct sockaddr *sa) +{ + if (sa->sa_family == AF_INET) { + const struct sockaddr_in *sin = (const struct sockaddr_in *)sa; + uint32_t ip = ntohl(sin->sin_addr.s_addr); + + if ((ip >> 24) == 10 || (ip >> 24) == 127 || (ip >> 24) == 0) + return 0; + if ((ip & 0xFFF00000U) == 0xAC100000U) + return 0; + if ((ip & 0xFFFF0000U) == 0xC0A80000U) + return 0; + if ((ip & 0xFFFF0000U) == 0xA9FE0000U) + return 0; + if ((ip & 0xF0000000U) == 0xE0000000U) + return 0; + return 1; + } + if (sa->sa_family == AF_INET6) { + const struct sockaddr_in6 *sin6 = (const struct sockaddr_in6 *)sa; + const uint8_t *ip = sin6->sin6_addr.s6_addr; + + if (IN6_IS_ADDR_LOOPBACK(&sin6->sin6_addr) || + IN6_IS_ADDR_UNSPECIFIED(&sin6->sin6_addr) || + IN6_IS_ADDR_MULTICAST(&sin6->sin6_addr)) + return 0; + if ((ip[0] & 0xFE) == 0xFC) + return 0; + if (ip[0] == 0xFE && (ip[1] & 0xC0) == 0x80) + return 0; + return 1; + } + return 0; +} + +static int target_is_allowed(const char *host, const char *port, const char *path) +{ + struct addrinfo hints, *res = NULL, *rp; + int allowed = 0; + + if (!path || path[0] != '/' || target_host_has_forbidden_syntax(host)) + return 0; + if (strcmp(port, "443") != 0) + return 0; + + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + hints.ai_socktype = SOCK_STREAM; + if (getaddrinfo(host, port, &hints, &res) != 0) + return 0; + + allowed = 1; + for (rp = res; rp; rp = rp->ai_next) { + if (!sockaddr_is_public(rp->ai_addr)) { + allowed = 0; + break; + } + } + freeaddrinfo(res); + return allowed; +} + static int lfd = -1; static WOLFSSL_CTX *srv_ctx = NULL; static WOLFSSL_CTX *cli_ctx = NULL; @@ -400,7 +490,7 @@ static int parse_target_from_path(const uint8_t *value, size_t len, return 0; } -static int tcp_connect(const char *host, const char *port) +static int tcp_connect(const char *host, const char *port, int require_public_target) { struct addrinfo hints, *res = NULL, *rp; int fd = -1; @@ -419,6 +509,8 @@ static int tcp_connect(const char *host, const char *port) tv.tv_usec = 0; for (rp = res; rp; rp = rp->ai_next) { + if (require_public_target && !sockaddr_is_public(rp->ai_addr)) + continue; fd = socket(rp->ai_family, rp->ai_socktype, rp->ai_protocol); if (fd < 0) continue; @@ -585,7 +677,7 @@ static int connect_target_connection(struct target_conn *tc) return 0; close_target_connection(tc); - tc->fd = tcp_connect(tc->up.host, tc->up.port); + tc->fd = tcp_connect(tc->up.host, tc->up.port, tc->require_public_target); if (tc->fd < 0) return -1; @@ -735,7 +827,10 @@ static int forward_to_dynamic_target(struct req *req, uint8_t *out, uint32_t *ou return -1; if (parse_url(full, &tc.up) != 0) return -1; + if (!target_is_allowed(tc.up.host, tc.up.port, tc.up.path)) + return -1; tc.fd = -1; + tc.require_public_target = 1; if (forward_to_upstream(&tc, req, "application/oblivious-dns-message", out, out_len) != 0) { close_target_connection(&tc); @@ -821,6 +916,11 @@ static int in_header_cb(nghttp2_session *session, if (frame->hd.type != NGHTTP2_HEADERS || frame->headers.cat != NGHTTP2_HCAT_REQUEST) return 0; + if (req->stream_id != 0 && frame->hd.stream_id != (int32_t)req->stream_id) { + nghttp2_submit_rst_stream(session, NGHTTP2_FLAG_NONE, + frame->hd.stream_id, NGHTTP2_REFUSED_STREAM); + return 0; + } if (frame->hd.stream_id != (int32_t)req->stream_id) req->stream_id = frame->hd.stream_id; @@ -944,18 +1044,21 @@ static int in_frame_recv_cb(nghttp2_session *session, } nghttp2_session_send(session); - free(req->resp); - memset(req, 0, sizeof(*req)); return 0; } static int in_stream_close_cb(nghttp2_session *session, int32_t stream_id, uint32_t error_code, void *user_data) { + struct client *cl = (struct client *)user_data; + struct req *req = cl ? &cl->req : NULL; + (void)session; - (void)stream_id; (void)error_code; - (void)user_data; + if (req && req->stream_id == (uint32_t)stream_id) { + free(req->resp); + memset(req, 0, sizeof(*req)); + } return 0; } @@ -982,7 +1085,7 @@ static void client_read(int fd, short revents, void *arg) } nghttp2_session_callbacks *cbs = NULL; - nghttp2_settings_entry iv[1] = {{ NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 100 }}; + nghttp2_settings_entry iv[1] = {{ NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 1 }}; if (nghttp2_session_callbacks_new(&cbs) != 0) { free_client(cl); diff --git a/src/dohd.c b/src/dohd.c index 1b8b44d..76cb891 100644 --- a/src/dohd.c +++ b/src/dohd.c @@ -73,6 +73,21 @@ /* DNS request timeout in milliseconds */ #define DNS_REQUEST_TIMEOUT_MS 5000 +/* Maximum time (ms) a single HTTP/2 stream may stay open before its request is + * completed and forwarded upstream. Reaps half-open/stalled streams that open + * HEADERS but never send END_STREAM, bounding occupancy of the request pool. */ +#define H2_STREAM_TIMEOUT_MS 10000 + +/* Connection idle timeout (ms). A client connection with no read activity for + * this long is closed, reaping slow-loris / flow-control-stalled connections + * that would otherwise hold client and request pool slots open. */ +#define CLIENT_IDLE_TIMEOUT_MS 30000 + +/* Maximum decoded request header list size advertised to peers via + * SETTINGS_MAX_HEADER_LIST_SIZE. DoH requests carry only a handful of small + * headers; this caps HPACK-expanded header sets and is enforced by nghttp2. */ +#define H2_MAX_HEADER_LIST_SIZE (8 * 1024) + #define IP6_LOCALHOST { 0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1 } #define DOHD_REQ_MIN 20 #define STR_HTTP2_PREFACE "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n" @@ -265,6 +280,10 @@ static char *authorized_proxy_dir = NULL; static char *resolved_proxy_dir = NULL; static odoh_target_ctx odoh_target = {}; static proxy_auth_set proxy_set = {}; +static int signal_pipe[2] = { -1, -1 }; +static volatile sig_atomic_t reload_proxy_auth_pending = 0; +static volatile sig_atomic_t reload_odoh_target_pending = 0; +static volatile sig_atomic_t print_stats_pending = 0; #ifndef _MUSL_ void *sigset(int sig, void (*disp)(int)); #endif @@ -297,6 +316,7 @@ struct client_data { int h2; nghttp2_session *h2_session; int doh_sd; + evquick_timer *idle_timer; /* Connection idle timeout timer */ struct client_data *next; struct req_slot *list; }; @@ -304,6 +324,7 @@ struct client_data { static void dohd_reply(int fd, short __attribute__((unused)) revents, void *arg); static void dohd_destroy_request(struct req_slot *req); +static void dns_request_timeout(void *arg); /* Memory pools for client_data and req_slot */ static mempool_t *client_pool = NULL; @@ -436,13 +457,63 @@ static int reload_odoh_target(void) return 0; } +static void notify_signal_pipe(void) +{ + char marker = 's'; + + if (signal_pipe[1] >= 0) { + if (write(signal_pipe[1], &marker, 1) < 0) { + /* best-effort wakeup */ + } + } +} + static void sig_stats(int __attribute__((unused)) signo) { if (oblivion_mode) { - reload_proxy_authorization(); - reload_odoh_target(); + reload_proxy_auth_pending = 1; + reload_odoh_target_pending = 1; + } + print_stats_pending = 1; + notify_signal_pipe(); +} + +static void handle_pending_signals(int __attribute__((unused)) fd, + short __attribute__((unused)) revents, + void __attribute__((unused)) *arg) +{ + char drain[32]; + + if (signal_pipe[0] >= 0) { + while (read(signal_pipe[0], drain, sizeof(drain)) > 0) { + } + } + + if (reload_proxy_auth_pending) { + reload_proxy_auth_pending = 0; + if (reload_proxy_authorization() != 0) + dohprint(DOH_WARN, "Failed to reload authorized proxies"); } - printstats(); + if (reload_odoh_target_pending) { + reload_odoh_target_pending = 0; + if (reload_odoh_target() != 0) + dohprint(DOH_WARN, "Failed to reload ODoH target"); + } + if (print_stats_pending) { + print_stats_pending = 0; + printstats(); + } +} + +static int set_fd_nonblocking(int fd) +{ + int flags = fcntl(fd, F_GETFL, 0); + + if (flags < 0) + return -1; + if ((flags & O_NONBLOCK) != 0) + return 0; + return fcntl(fd, F_SETFL, flags | O_NONBLOCK); } @@ -485,6 +556,12 @@ static void dohd_destroy_client(struct client_data *cd) DOH_Stats.pending_requests--; rp = nxt; } + /* Cancel idle timeout timer if pending */ + if (cd->idle_timer) { + evquick_deltimer(cd->idle_timer); + cd->idle_timer = NULL; + } + /* Remove events from file desc */ if (cd->ev_doh) { evquick_delevent(cd->ev_doh); @@ -512,6 +589,33 @@ static void dohd_destroy_client(struct client_data *cd) check_stats(); } +/* Connection idle timeout - no read activity for CLIENT_IDLE_TIMEOUT_MS. + * Reaps slow-loris / flow-control-stalled connections that hold pool slots. */ +static void client_idle_timeout(void *arg) +{ + struct client_data *cd = (struct client_data *)arg; + /* One-shot timer has fired; clear stored pointer before destroying so + * dohd_destroy_client() does not attempt to cancel an already-freed timer. */ + cd->idle_timer = NULL; + if (!client_hash_exists(cd)) + return; + dohprint(DOH_DEBUG, "Closing idle client connection"); + DOH_Stats.socket_errors++; + dohd_destroy_client(cd); +} + +/* (Re)arm the per-connection idle timeout. Called at connection setup and on + * every read with activity, so the timer measures time since last activity. */ +static void client_arm_idle_timer(struct client_data *cd) +{ + if (cd->idle_timer) { + evquick_deltimer(cd->idle_timer); + cd->idle_timer = NULL; + } + cd->idle_timer = evquick_addtimer(CLIENT_IDLE_TIMEOUT_MS, 0, + client_idle_timeout, cd); +} + static void clean_exit(int __attribute__((unused)) signo) { /* Iterate hash table and destroy all clients */ @@ -535,6 +639,8 @@ static void clean_exit(int __attribute__((unused)) signo) mempool_free(request_pool, rp); rp = nxt; } + if (cd->idle_timer) + evquick_deltimer(cd->idle_timer); if (cd->ev_doh) evquick_delevent(cd->ev_doh); if (cd->ssl) @@ -554,6 +660,10 @@ static void clean_exit(int __attribute__((unused)) signo) odoh_target_free(&odoh_target); free(resolved_proxy_dir); resolved_proxy_dir = NULL; + if (signal_pipe[0] >= 0) + close(signal_pipe[0]); + if (signal_pipe[1] >= 0) + close(signal_pipe[1]); fprintf(stderr, "Cleanup, exiting...\n"); #ifdef DMALLOC @@ -578,13 +688,14 @@ struct req_slot *dns_create_request_h2(struct client_data *cd, uint32_t stream_i req = nghttp2_session_get_stream_user_data(cd->h2_session, stream_id); if (req) { dohprint(DOH_WARN, "W: request is not null for this stream id\n"); - + return req; } req = mempool_alloc(request_pool); if (req == NULL) { dohprint(DOH_ERR, "Request pool exhausted (capacity: %u)", MAX_REQUESTS); return req; } + memset(req, 0, sizeof(*req)); req->resolver = next_resolver(); req->resolver_sz = sizeof(struct sockaddr_in); /* Change AF / socksize if IPV6 */ @@ -620,11 +731,19 @@ struct req_slot *dns_create_request_h2(struct client_data *cd, uint32_t stream_i req->owner_fd = cd->doh_sd; req->h2_stream_id = stream_id; req->timeout_timer = NULL; + req->ev_dns = NULL; req->is_odoh = 0; req->content_type_seen = 0; req->is_h2_get = 0; memset(&req->odoh_ctx, 0, sizeof(req->odoh_ctx)); nghttp2_session_set_stream_user_data(cd->h2_session, stream_id, req); + + /* Arm a stream timeout to reap half-open streams that never complete their + * request (no END_STREAM). This is re-armed with the upstream timeout once + * the request is actually forwarded in dns_send_request_h2(). */ + req->timeout_timer = evquick_addtimer(H2_STREAM_TIMEOUT_MS, 0, + dns_request_timeout, req); + return req; } @@ -671,7 +790,7 @@ static int dns_send_request_h2(struct req_slot *req) return -1; if (odoh_target_decrypt_query(&odoh_target, req->h2_request_buffer, (uint16_t)req->h2_request_len, - plain_dns, &plain_len, &req->odoh_ctx) != 0) { + plain_dns, sizeof(plain_dns), &plain_len, &req->odoh_ctx) != 0) { return -1; } if (plain_len > sizeof(req->h2_request_buffer)) @@ -703,7 +822,12 @@ static int dns_send_request_h2(struct req_slot *req) return -1; } - /* Start timeout timer - if upstream doesn't respond, return 504 */ + /* Cancel the stream-open timeout armed at request creation, then start the + * upstream timeout - if upstream doesn't respond, return 504 */ + if (req->timeout_timer) { + evquick_deltimer(req->timeout_timer); + req->timeout_timer = NULL; + } req->timeout_timer = evquick_addtimer(DNS_REQUEST_TIMEOUT_MS, 0, dns_request_timeout, req); @@ -726,71 +850,54 @@ static ssize_t client_ssl_write(struct client_data *cd, const void *data, size_t */ static int dns_skip_question(uint8_t **record, int maxlen) { - int skip = 0; - size_t len = (size_t)maxlen; - uint8_t *cur = *record; - int consumed = 0; - while (len > 0) { + uint8_t *start = *record; + uint8_t *cur = start; + const uint8_t *buf_end = start + (size_t)maxlen; + + while (cur < buf_end) { uint8_t c = *cur; if ((c & 0xC0) == 0xC0) { - if (len < 2) + if ((size_t)(buf_end - cur) < 2) return -1; cur += 2; - consumed += 2; - len -= 2; break; } if (c == 0) { cur += 1; - consumed += 1; - len -= 1; break; } - if (c > 63 || len < (size_t)c + 1) + if (c > 63 || (size_t)(buf_end - cur) < (size_t)c + 1) return -1; cur += c + 1; - consumed += c + 1; - len -= c + 1; } - if (len < DNSQ_SUFFIX_LEN) + if ((size_t)(buf_end - cur) < DNSQ_SUFFIX_LEN) return -1; cur += DNSQ_SUFFIX_LEN; - consumed += DNSQ_SUFFIX_LEN; *record = cur; - skip = consumed; - return skip; + return (int)(cur - start); } -static int dns_skip_rr_name(uint8_t **record, size_t *len) +static int dns_skip_rr_name(uint8_t **record, const uint8_t *buf_end) { uint8_t *cur = *record; - size_t remain = *len; - int consumed = 0; - while (remain > 0) { + + while (cur < buf_end) { uint8_t c = *cur; if ((c & 0xC0) == 0xC0) { - if (remain < 2) + if ((size_t)(buf_end - cur) < 2) return -1; cur += 2; - consumed += 2; - remain -= 2; *record = cur; - *len = remain; - return consumed; + return 2; } if (c == 0) { cur += 1; - consumed += 1; - remain -= 1; *record = cur; - *len = remain; - return consumed; + return 1; } - if (c > 63 || remain < (size_t)c + 1) + if (c > 63 || (size_t)(buf_end - cur) < (size_t)c + 1) return -1; cur += c + 1; - consumed += c + 1; - remain -= c + 1; } return -1; } @@ -799,26 +906,38 @@ static uint32_t dnsreply_min_age(const void *p, size_t len) { int i = 0; const struct dns_header *hdr = p; - uint8_t *record = ((uint8_t *)p + sizeof(struct dns_header)); - int skip = 0; - int answers = ntohs(hdr->ancount) + ntohs(hdr->nscount) + ntohs(hdr->arcount); + const uint8_t *buf = p; + const uint8_t *buf_end; + uint8_t *record; + int answers; uint32_t min_ttl = 3600; + + if (len < sizeof(struct dns_header)) + return min_ttl; + + record = (uint8_t *)buf + sizeof(struct dns_header); + buf_end = buf + len; + answers = ntohs(hdr->ancount) + ntohs(hdr->nscount) + ntohs(hdr->arcount); if (answers < 1) return -1; for (i = 0; i < ntohs(hdr->qdcount); i++) { - skip = dns_skip_question(&record, len); - if (skip < DNSQ_SUFFIX_LEN) { + if (record > buf_end) + return min_ttl; + if (dns_skip_question(&record, (int)(buf_end - record)) < DNSQ_SUFFIX_LEN) return min_ttl; - } - len -= skip; } for (i = 0; i < answers; i++) { uint32_t ttl; uint32_t datalen; - if (dns_skip_rr_name(&record, &len) < 0) + size_t remain; + + if (record > buf_end) return min_ttl; - if (len < 10) + if (dns_skip_rr_name(&record, buf_end) < 0) + return min_ttl; + remain = (size_t)(buf_end - record); + if (remain < 10) return min_ttl; ttl = (record[4] << 24 ) + (record[5] << 16 ) + @@ -826,18 +945,31 @@ static uint32_t dnsreply_min_age(const void *p, size_t len) record[7]; datalen = (record[8] << 8) + record[9]; - if (len < (10U + datalen)) + if (remain < (10U + datalen)) return min_ttl; if (ttl && (ttl < min_ttl)) min_ttl = ttl; record += 10 + datalen; - len -= 10 + datalen; } return min_ttl; } #define DOHD_MAX_REPLY (DNS_BUFFER_MAXSIZE) +static void dns_request_finish_upstream(struct req_slot *req) +{ + if (!req) + return; + if (req->ev_dns) { + evquick_delevent(req->ev_dns); + req->ev_dns = NULL; + } + if (req->dns_sd >= 0) { + close(req->dns_sd); + req->dns_sd = -1; + } +} + static void dohd_destroy_request(struct req_slot *req) { struct client_data *cd; @@ -868,18 +1000,16 @@ static void dohd_destroy_request(struct req_slot *req) evquick_deltimer(req->timeout_timer); req->timeout_timer = NULL; } - if (req->ev_dns) { - evquick_delevent(req->ev_dns); - req->ev_dns = NULL; - } - close(req->dns_sd); + dns_request_finish_upstream(req); if (req->h2_response_data) { DOH_Stats.mem -= req->h2_response_len; free(req->h2_response_data); req->h2_response_data = NULL; } - if (client_valid && cd->h2_session && req->h2_stream_id) { + if (client_valid && cd->h2_session && req->h2_stream_id && + nghttp2_session_get_stream_user_data(cd->h2_session, + req->h2_stream_id) == req) { nghttp2_session_set_stream_user_data(cd->h2_session, req->h2_stream_id, NULL); } @@ -993,6 +1123,7 @@ static void dohd_reply(int fd, short __attribute__((unused)) revents, DOH_Stats.mem += resp_len; memcpy(req->h2_response_data, resp_ptr, resp_len); req->h2_response_len = resp_len; + dns_request_finish_upstream(req); memset(&data_prd, 0, sizeof(data_prd)); data_prd.source.ptr = req; data_prd.read_callback = h2_cb_req_submit; @@ -1172,11 +1303,12 @@ static int h2_cb_on_header(nghttp2_session *session, memcpy(b64tmp, value + strlen(GETDNS), b64len); b64tmp[b64len] = '\0'; req->h2_request_len = 0; - if(dohd_url64_check(b64tmp) == 0) { + if(dohd_url64_check(b64tmp, b64len) == 0) { dohd_destroy_request(req); return 0; } - outlen = dohd_url64_decode(b64tmp, req->h2_request_buffer); + outlen = dohd_url64_decode(b64tmp, b64len, + req->h2_request_buffer, sizeof(req->h2_request_buffer)); if (outlen <= 0) { dohd_destroy_request(req); return 0; @@ -1218,6 +1350,8 @@ static void tls_read(__attribute__((unused)) int fd, short __attribute__((unused /* Verify client still exists in hash table */ if (!client_hash_exists(cd)) return; + /* Read activity on this connection: reset the idle timeout */ + client_arm_idle_timer(cd); if (!cd->tls_handshake_done) { /* Establish TLS connection */ ret = wolfSSL_accept(cd->ssl); @@ -1236,8 +1370,9 @@ static void tls_read(__attribute__((unused)) int fd, short __attribute__((unused } if (wolfSSL_ALPN_GetProtocol(cd->ssl, &proto, &proto_len) && (2 == proto_len) && strncmp(proto, "h2", 2) == 0) { - nghttp2_settings_entry iv[1] = { - {NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 100} + nghttp2_settings_entry iv[2] = { + {NGHTTP2_SETTINGS_MAX_CONCURRENT_STREAMS, 100}, + {NGHTTP2_SETTINGS_MAX_HEADER_LIST_SIZE, H2_MAX_HEADER_LIST_SIZE} }; nghttp2_session_callbacks *h2_cbs; @@ -1251,7 +1386,7 @@ static void tls_read(__attribute__((unused)) int fd, short __attribute__((unused nghttp2_session_server_new(&cd->h2_session, h2_cbs, cd); nghttp2_session_callbacks_del(h2_cbs); cd->h2 = 1; - nghttp2_submit_settings(cd->h2_session, NGHTTP2_FLAG_NONE, iv, 1); + nghttp2_submit_settings(cd->h2_session, NGHTTP2_FLAG_NONE, iv, 2); } cd->tls_handshake_done = 1; } @@ -1385,6 +1520,9 @@ static void dohd_new_connection(int __attribute__((unused)) fd, /* Insert into hash table - O(1) */ client_hash_insert(cd); + /* Arm the connection idle timeout */ + client_arm_idle_timer(cd); + DOH_Stats.mem += sizeof(struct client_data); DOH_Stats.clients++; check_stats(); @@ -1662,6 +1800,16 @@ int main(int argc, char *argv[]) /* Initialize libevquick */ evquick_init(); + if (pipe(signal_pipe) != 0) { + dohprint(DOH_ERR, "ERROR: failed to create signal pipe"); + return -1; + } + if (set_fd_nonblocking(signal_pipe[0]) != 0 || + set_fd_nonblocking(signal_pipe[1]) != 0) { + dohprint(DOH_ERR, "ERROR: failed to make signal pipe non-blocking"); + return -1; + } + evquick_addevent(signal_pipe[0], EVQUICK_EV_READ, handle_pending_signals, NULL, NULL); /* Initialize memory pools */ client_pool = mempool_create(sizeof(struct client_data), MAX_CLIENTS); diff --git a/src/heap.h b/src/heap.h index fd3f0a8..781703d 100644 --- a/src/heap.h +++ b/src/heap.h @@ -16,17 +16,17 @@ #define DECLARE_HEAP(type, orderby) \ struct heap_element_##type { \ - uint32_t id; \ + uint64_t id; \ type data; \ }; \ struct heap_##type { \ uint32_t size; \ uint32_t n; \ - uint32_t last_id; \ + uint64_t last_id; \ struct heap_element_##type *top; \ }; \ typedef struct heap_##type heap_##type; \ -static inline int heap_insert(struct heap_##type *heap, type *el) \ +static inline uint64_t heap_insert(struct heap_##type *heap, type *el) \ { \ int i; \ struct heap_element_##type etmp; \ @@ -36,24 +36,24 @@ static inline int heap_insert(struct heap_##type *heap, type *el) (heap->n + 1) * sizeof(struct heap_element_##type)); \ if (!_tmp) { \ heap->n--; \ - return -1; \ + return UINT64_MAX; \ } \ heap->top = _tmp; \ heap->size++; \ } \ + if (heap->last_id == UINT64_MAX) \ + heap->last_id = 0; \ etmp.id = heap->last_id++; \ - if ((heap->last_id & 0x80000000U) != 0) \ - heap->last_id = 0; /* Wrap around */ \ if (heap->n == 1) { \ memcpy(&heap->top[1], &etmp, sizeof(struct heap_element_##type)); \ - return (int)etmp.id; \ + return etmp.id; \ } \ for (i = heap->n; ((i > 1) && \ (heap->top[i / 2].data.orderby > el->orderby)); i /= 2) { \ memcpy(&heap->top[i], &heap->top[i / 2], sizeof(struct heap_element_##type)); \ } \ memcpy(&heap->top[i], &etmp, sizeof(struct heap_element_##type)); \ - return (int)etmp.id; \ + return etmp.id; \ } \ static inline int heap_peek(struct heap_##type *heap, type *first) \ { \ @@ -81,7 +81,7 @@ static inline int heap_peek(struct heap_##type *heap, type *first) memcpy(&heap->top[i], last, sizeof(struct heap_element_##type)); \ return 0; \ } \ -static inline int heap_delete(struct heap_##type *heap, int id) \ +static inline int heap_delete(struct heap_##type *heap, uint64_t id) \ { \ int found = 0; \ int i; \ @@ -136,4 +136,3 @@ static inline void heap_destroy(heap_##type *h) free(h->top); \ free(h); \ } - diff --git a/src/libevquick.c b/src/libevquick.c index e3ee92d..b6aeecc 100644 --- a/src/libevquick.c +++ b/src/libevquick.c @@ -61,7 +61,7 @@ struct evquick_event struct evquick_timer { unsigned long long interval; - int id; + uint64_t id; short flags; #ifdef EVQUICK_PTHREAD void (*callback)(CTX ctx, void *arg); @@ -74,7 +74,7 @@ struct evquick_timer struct evquick_timer_instance { unsigned long long expire; - int id; + uint64_t id; struct evquick_timer *ev_timer; }; typedef struct evquick_timer_instance evquick_timer_instance; diff --git a/src/odoh.c b/src/odoh.c index 6fc8e11..c57932d 100644 --- a/src/odoh.c +++ b/src/odoh.c @@ -307,7 +307,7 @@ static int build_plaintext(const uint8_t *dns, uint16_t dns_len, } static int parse_plaintext_dns(const uint8_t *plain, uint16_t plain_len, - uint8_t *dns_out, uint16_t *dns_out_len) + uint8_t *dns_out, uint16_t dns_out_cap, uint16_t *dns_out_len) { uint16_t dns_len; uint16_t pad_len; @@ -323,6 +323,8 @@ static int parse_plaintext_dns(const uint8_t *plain, uint16_t plain_len, pad_len = be16(plain + 2 + dns_len); if ((size_t)(2 + dns_len + 2 + pad_len) > plain_len) return -1; + if (dns_len > dns_out_cap) + return -1; for (i = 0; i < pad_len; i++) { if (plain[2 + dns_len + 2 + i] != 0) @@ -520,11 +522,12 @@ int odoh_client_encrypt_query(const odoh_config *cfg, int odoh_target_decrypt_query(odoh_target_ctx *target, const uint8_t *in, uint16_t in_len, - uint8_t *dns_out, uint16_t *dns_out_len, + uint8_t *dns_out, uint16_t dns_out_cap, uint16_t *dns_out_len, odoh_req_ctx *req_ctx) { #if !ODOH_HAVE_HPKE_CONTEXT_API - (void)target; (void)in; (void)in_len; (void)dns_out; (void)dns_out_len; (void)req_ctx; + (void)target; (void)in; (void)in_len; (void)dns_out; (void)dns_out_cap; + (void)dns_out_len; (void)req_ctx; return -1; #else odoh_message_view msg; @@ -561,6 +564,8 @@ int odoh_target_decrypt_query(odoh_target_ctx *target, enc = msg.encrypted; ct = msg.encrypted + enc_len; ct_len = (uint16_t)(msg.encrypted_len - enc_len); + if (ct_len < req_ctx->hpke.Nt) + return -1; if (odoh_hpke_init_open_context(&req_ctx->hpke, &req_ctx->hpke_ctx, &target->priv, enc, enc_len, @@ -574,7 +579,8 @@ int odoh_target_decrypt_query(odoh_target_ctx *target, aad, aad_len, (byte *)ct, ct_len, plain) != 0) return -1; - if (parse_plaintext_dns(plain, (uint16_t)(ct_len - req_ctx->hpke.Nt), dns_out, dns_out_len) != 0) + if (parse_plaintext_dns(plain, (uint16_t)(ct_len - req_ctx->hpke.Nt), + dns_out, dns_out_cap, dns_out_len) != 0) return -1; memcpy(req_ctx->q_plain, plain, (size_t)(ct_len - req_ctx->hpke.Nt)); @@ -669,10 +675,11 @@ int odoh_target_encrypt_response(const odoh_req_ctx *req_ctx, int odoh_client_decrypt_response(odoh_client_ctx *client_ctx, const uint8_t *in, uint16_t in_len, - uint8_t *dns_out, uint16_t *dns_out_len) + uint8_t *dns_out, uint16_t dns_out_cap, uint16_t *dns_out_len) { #if !ODOH_HAVE_HPKE_CONTEXT_API - (void)client_ctx; (void)in; (void)in_len; (void)dns_out; (void)dns_out_len; + (void)client_ctx; (void)in; (void)in_len; (void)dns_out; (void)dns_out_cap; + (void)dns_out_len; return -1; #else odoh_message_view msg; @@ -721,7 +728,7 @@ int odoh_client_decrypt_response(odoh_client_ctx *client_ctx, msg.encrypted + pt_len, client_ctx->hpke.Nt, aad, aad_len); if (ret == 0) - ret = parse_plaintext_dns(plain, pt_len, dns_out, dns_out_len); + ret = parse_plaintext_dns(plain, pt_len, dns_out, dns_out_cap, dns_out_len); } wc_AesFree(&aes); diff --git a/src/odoh.h b/src/odoh.h index fc6ca49..b3b94c1 100644 --- a/src/odoh.h +++ b/src/odoh.h @@ -64,11 +64,11 @@ int odoh_client_encrypt_query(const odoh_config *cfg, int odoh_client_decrypt_response(odoh_client_ctx *client_ctx, const uint8_t *in, uint16_t in_len, - uint8_t *dns_out, uint16_t *dns_out_len); + uint8_t *dns_out, uint16_t dns_out_cap, uint16_t *dns_out_len); int odoh_target_decrypt_query(odoh_target_ctx *target, const uint8_t *in, uint16_t in_len, - uint8_t *dns_out, uint16_t *dns_out_len, + uint8_t *dns_out, uint16_t dns_out_cap, uint16_t *dns_out_len, odoh_req_ctx *req_ctx); int odoh_target_encrypt_response(const odoh_req_ctx *req_ctx, diff --git a/src/url64.c b/src/url64.c index 3005a89..c09b067 100644 --- a/src/url64.c +++ b/src/url64.c @@ -19,6 +19,8 @@ */ #include +#include +#include static const unsigned char asciitable[256] = { 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, 64, @@ -42,48 +44,76 @@ static const unsigned char asciitable[256] = { // returns an estimation of the length of the data once decoded int dohd_url64_declen(int len) { return ((len + 3) >> 2) * 3; } -// assumes null terminated string -// no padding equals check (no modulo 4) -// returns 0 if not base else length of base encoded string -int dohd_url64_check(const char *in) { - if(!in) { return 0; } - register int c; - unsigned char *bufin; - bufin = (unsigned char *)in; - for(c=0; bufin[c] != '\0'; c++) - if(asciitable[*(bufin+c)] > 63) +static size_t dohd_url64_declen_size(size_t len) +{ + return ((len + 3U) >> 2) * 3U; +} + +int dohd_url64_check(const char *in, size_t in_len) { + size_t c; + const unsigned char *bufin; + + if (!in || in_len > (size_t)INT_MAX) + return 0; + bufin = (const unsigned char *)in; + for (c = 0; c < in_len; c++) { + if (asciitable[bufin[c]] > 63) return 0; - return(c); + } + return (int)in_len; } -int dohd_url64_decode(const char *src, uint8_t *dest) { - register const unsigned char *bufin; - register unsigned char *bufout; - register int nprbytes; +int dohd_url64_decode(const char *src, size_t src_len, uint8_t *dest, size_t dest_cap) { + const unsigned char *bufin; + unsigned char *bufout; + size_t nprbytes; const unsigned char *_buf = (const unsigned char *) src; - bufin = _buf; - while (asciitable[*(bufin++)] <= 63); - nprbytes = bufin - _buf - 1; + size_t out_cap; + if (!src || !dest || dest_cap == 0) + return -1; + if (src_len > (size_t)INT_MAX) + return -1; + if (dohd_url64_check(src, src_len) != (int)src_len) + return -1; + if (dohd_url64_declen_size(src_len) + 1 > dest_cap) + return -1; + + bufin = _buf; + nprbytes = src_len; + out_cap = dest_cap - 1; bufout = (unsigned char *) dest; bufin = _buf; while (nprbytes > 4) { + if (out_cap < 3) + return -1; *(bufout++) = (unsigned char) (asciitable[*bufin] << 2 | asciitable[bufin[1]] >> 4); *(bufout++) = (unsigned char) (asciitable[bufin[1]] << 4 | asciitable[bufin[2]] >> 2); *(bufout++) = (unsigned char) (asciitable[bufin[2]] << 6 | asciitable[bufin[3]]); bufin += 4; nprbytes -= 4; + out_cap -= 3; } - if (nprbytes > 1) + if (nprbytes > 1) { + if (out_cap < 1) + return -1; *(bufout++) = (unsigned char) (asciitable[*bufin] << 2 | asciitable[bufin[1]] >> 4); - if (nprbytes > 2) + out_cap--; + } + if (nprbytes > 2) { + if (out_cap < 1) + return -1; *(bufout++) = (unsigned char) (asciitable[bufin[1]] << 4 | asciitable[bufin[2]] >> 2); - if (nprbytes > 3) + out_cap--; + } + if (nprbytes > 3) { + if (out_cap < 1) + return -1; *(bufout++) = (unsigned char) (asciitable[bufin[2]] << 6 | asciitable[bufin[3]]); + } *(bufout++) = '\0'; - // return the length of decoded - return(bufout-(unsigned char*)dest-1); + return (int)(bufout - (unsigned char *)dest - 1); } diff --git a/src/url64.h b/src/url64.h index f7110da..1ad073e 100644 --- a/src/url64.h +++ b/src/url64.h @@ -19,7 +19,10 @@ #ifndef URL64_H_INCLUDED #define URL64_H_INCLUDED +#include +#include + int dohd_url64_declen(int len); -int dohd_url64_decode(const char *src, uint8_t *dest); -int dohd_url64_check(const char *in); +int dohd_url64_decode(const char *src, size_t src_len, uint8_t *dest, size_t dest_cap); +int dohd_url64_check(const char *in, size_t in_len); #endif diff --git a/test/test_dns_parser.c b/test/test_dns_parser.c index 0bbdc6b..18ab8db 100644 --- a/test/test_dns_parser.c +++ b/test/test_dns_parser.c @@ -57,66 +57,62 @@ struct __attribute__((packed)) dns_header { /* Skip a DNS question section entry */ static int dns_skip_question(uint8_t **record, int maxlen) { - int len = 0; - uint8_t *r = *record; - - while (*r != 0) { - if (*r > 63) { - /* Compression pointer - 2 bytes */ - len += 2; + uint8_t *start = *record; + uint8_t *r = start; + const uint8_t *end = start + (size_t)maxlen; + + while (r < end) { + uint8_t c = *r; + if ((c & 0xC0) == 0xC0) { + if ((size_t)(end - r) < 2) + return -1; r += 2; break; } - len += *r + 1; - r += *r + 1; - if (len > maxlen) + if (c == 0) { + r++; + break; + } + if (c > 63 || (size_t)(end - r) < (size_t)c + 1) return -1; + r += c + 1; } - if (*r == 0) { - len++; - r++; - } - /* Skip QTYPE and QCLASS (4 bytes) */ - len += 4; - r += 4; - - if (len > maxlen) + if ((size_t)(end - r) < 4) return -1; - + r += 4; *record = r; - return len; + return (int)(r - start); } /* Skip RR name (handles compression) */ -static int dns_skip_rr_name(uint8_t **record, size_t *len) { +static int dns_skip_rr_name(uint8_t **record, const uint8_t *end) { uint8_t *r = *record; - size_t consumed = 0; - - while (*r != 0) { - if (*r >= 0xC0) { - /* Compression pointer */ - consumed += 2; + + while (r < end) { + uint8_t c = *r; + if ((c & 0xC0) == 0xC0) { + if ((size_t)(end - r) < 2) + return -1; r += 2; *record = r; - *len -= consumed; return 0; } - consumed += *r + 1; - r += *r + 1; - if (consumed > *len) + if (c == 0) { + r++; + *record = r; + return 0; + } + if (c > 63 || (size_t)(end - r) < (size_t)c + 1) return -1; + r += c + 1; } - /* Skip null terminator */ - consumed++; - r++; - *record = r; - *len -= consumed; - return 0; + return -1; } /* Extract minimum TTL from DNS response */ static uint32_t dnsreply_min_age(const void *p, size_t len) { struct dns_header *hdr = (struct dns_header *)p; + const uint8_t *end = (const uint8_t *)p + len; uint8_t *record; uint32_t min_ttl = 0xFFFFFFFF; uint16_t qdcount, ancount, nscount, arcount; @@ -136,31 +132,38 @@ static uint32_t dnsreply_min_age(const void *p, size_t len) { /* Skip questions */ for (i = 0; i < qdcount; i++) { - if (dns_skip_question(&record, len) < 0) + if (dns_skip_question(&record, (int)(end - record)) < 0) return 0; } /* Process answer, authority, and additional sections */ int total_rr = ancount + nscount + arcount; - for (i = 0; i < total_rr && len > 10; i++) { + for (i = 0; i < total_rr; i++) { uint32_t ttl; uint16_t datalen; + uint32_t ttl_net; + uint16_t datalen_net; + size_t remain; - if (dns_skip_rr_name(&record, &len) < 0) + if (dns_skip_rr_name(&record, end) < 0) return min_ttl; - if (len < 10) + remain = (size_t)(end - record); + if (remain < 10) return min_ttl; /* TYPE (2) + CLASS (2) + TTL (4) + RDLENGTH (2) = 10 bytes */ - ttl = ntohl(*(uint32_t *)(record + 4)); - datalen = ntohs(*(uint16_t *)(record + 8)); + memcpy(&ttl_net, record + 4, sizeof(ttl_net)); + memcpy(&datalen_net, record + 8, sizeof(datalen_net)); + ttl = ntohl(ttl_net); + datalen = ntohs(datalen_net); + if (remain < (size_t)(10 + datalen)) + return min_ttl; if (ttl < min_ttl && ttl > 0) min_ttl = ttl; record += 10 + datalen; - len -= 10 + datalen; } return (min_ttl == 0xFFFFFFFF) ? 0 : min_ttl; @@ -291,6 +294,40 @@ static int test_dns_truncated(void) { return 1; } +static int test_dns_truncated_question_regression(void) { + uint8_t response[] = { + 0x12, 0x34, + 0x81, 0x80, + 0x00, 0x01, + 0x00, 0x01, + 0x00, 0x00, + 0x00, 0x00, + 0x3f, 'a', 'a', 'a', 'a' + }; + + TEST_ASSERT(dnsreply_min_age(response, sizeof(response)) == 0, + "truncated question returns 0 without over-reading"); + return 1; +} + +static int test_dns_truncated_answer_regression(void) { + uint8_t response[] = { + 0x12, 0x34, + 0x81, 0x80, + 0x00, 0x00, + 0x00, 0x01, + 0x00, 0x00, + 0x00, 0x00, + 0x00, + 0x00, 0x01, + 0x00 + }; + + TEST_ASSERT(dnsreply_min_age(response, sizeof(response)) == 0xFFFFFFFF, + "truncated answer returns sentinel without over-reading"); + return 1; +} + /* Test: Empty response (no answers) */ static int test_dns_no_answers(void) { uint8_t response[] = { @@ -365,6 +402,8 @@ int main(int argc, char **argv) { test_dns_ttl_extraction(); test_dns_multiple_ttls(); test_dns_truncated(); + test_dns_truncated_question_regression(); + test_dns_truncated_answer_regression(); test_dns_no_answers(); test_dns_long_name(); test_dns_compression(); diff --git a/test/test_heap.c b/test/test_heap.c index 8d2b6f5..7d788de 100644 --- a/test/test_heap.c +++ b/test/test_heap.c @@ -87,8 +87,8 @@ static int test_heap_single_insert(void) { test_timer t = { .expire = 100, .value = 42 }; test_timer out; - int id = heap_insert(h, &t); - TEST_ASSERT(id >= 0, "heap_insert returns valid id"); + uint64_t id = heap_insert(h, &t); + TEST_ASSERT(id != UINT64_MAX, "heap_insert returns valid id"); TEST_ASSERT(h->n == 1, "heap has 1 element after insert"); test_timer *first = heap_first(h); @@ -139,12 +139,12 @@ static int test_heap_ordering(void) { /* Test heap delete by id */ static int test_heap_delete(void) { heap_test_timer *h = heap_init(); - test_timer t, out; - int id1, id2, id3; + test_timer t; + uint64_t id2; - t.expire = 100; t.value = 1; id1 = heap_insert(h, &t); + t.expire = 100; t.value = 1; heap_insert(h, &t); t.expire = 200; t.value = 2; id2 = heap_insert(h, &t); - t.expire = 300; t.value = 3; id3 = heap_insert(h, &t); + t.expire = 300; t.value = 3; heap_insert(h, &t); TEST_ASSERT(h->n == 3, "heap has 3 elements"); @@ -227,8 +227,8 @@ static int test_heap_growth(void) { for (i = 0; i < 100; i++) { t.expire = i; t.value = i; - int id = heap_insert(h, &t); - TEST_ASSERT(id >= 0, "insert during growth succeeds"); + uint64_t id = heap_insert(h, &t); + TEST_ASSERT(id != UINT64_MAX, "insert during growth succeeds"); } TEST_ASSERT(h->n == 100, "heap has 100 elements"); TEST_ASSERT(h->size >= 100, "heap size grew appropriately"); @@ -243,7 +243,7 @@ static int test_heap_id_wrap(void) { test_timer t, out; /* Force id near wraparound point */ - h->last_id = 0x7FFFFFF0; + h->last_id = 0xFFFFFFFFFFFFFFF0ULL; for (int i = 0; i < 20; i++) { t.expire = i; @@ -262,6 +262,20 @@ static int test_heap_id_wrap(void) { return 1; } +static int test_heap_skips_error_sentinel(void) { + heap_test_timer *h = heap_init(); + test_timer t = { .expire = 1, .value = 1 }; + uint64_t id; + + h->last_id = UINT64_MAX; + id = heap_insert(h, &t); + TEST_ASSERT(id == 0, "heap_insert skips UINT64_MAX sentinel"); + TEST_ASSERT(h->last_id == 1, "heap_insert advances after sentinel wrap"); + + heap_destroy(h); + return 1; +} + int main(int argc, char **argv) { (void)argc; (void)argv; @@ -279,6 +293,7 @@ int main(int argc, char **argv) { test_heap_stress(); test_heap_growth(); test_heap_id_wrap(); + test_heap_skips_error_sentinel(); fprintf(stderr, "\n=== Results: %d/%d tests passed ===\n", tests_passed, tests_run); diff --git a/test/test_url64.c b/test/test_url64.c index d4f5ecf..0b68ceb 100644 --- a/test/test_url64.c +++ b/test/test_url64.c @@ -22,8 +22,7 @@ #include #include #include - -extern int dohd_url64_decode(const char *src, uint8_t *dest); +#include "../src/url64.h" static const int32_t hextable[] = { -1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1,-1, @@ -63,13 +62,13 @@ static void buf2hex(char *dst, const uint8_t *buf, const size_t len) { } static int test_compare(const char *url64, const char *hex) { - uint8_t out_bin[128]; - char out_hex[256]; + uint8_t out_bin[256]; + char out_hex[512]; int check_len = strlen(hex) / 2; fprintf(stderr,"%u %s\n",check_len, hex); char *check_bin = calloc(check_len+1, 1); hex2buf(check_bin, hex); - int out_len = dohd_url64_decode(url64, out_bin); + int out_len = dohd_url64_decode(url64, strlen(url64), out_bin, sizeof(out_bin)); if(out_len != check_len) return(0); if(memcmp(check_bin, out_bin, check_len) !=0) return(0); buf2hex(out_hex, out_bin, out_len); diff --git a/test/test_url64_extended.c b/test/test_url64_extended.c index 003c151..47ced99 100644 --- a/test/test_url64_extended.c +++ b/test/test_url64_extended.c @@ -24,10 +24,7 @@ #include #include #include - -extern int dohd_url64_decode(const char *src, uint8_t *dest); -extern int dohd_url64_check(const char *in); -extern int dohd_url64_declen(int len); +#include "../src/url64.h" static int tests_run = 0; static int tests_passed = 0; @@ -44,15 +41,14 @@ static int tests_passed = 0; /* Test empty string */ static int test_empty_string(void) { - uint8_t out[16]; - int check = dohd_url64_check(""); + int check = dohd_url64_check("", 0); TEST_ASSERT(check == 0, "Empty string check returns 0"); return 1; } /* Test NULL input */ static int test_null_input(void) { - int check = dohd_url64_check(NULL); + int check = dohd_url64_check(NULL, 0); TEST_ASSERT(check == 0, "NULL input check returns 0"); return 1; } @@ -63,10 +59,10 @@ static int test_single_char(void) { int len; /* Single valid char */ - int check = dohd_url64_check("A"); + int check = dohd_url64_check("A", 1); TEST_ASSERT(check == 1, "Single valid char check returns 1"); - len = dohd_url64_decode("A", out); + len = dohd_url64_decode("A", 1, out, sizeof(out)); TEST_ASSERT(len == 0, "Single char decode returns 0 bytes"); return 1; @@ -76,21 +72,21 @@ static int test_single_char(void) { static int test_invalid_chars(void) { int check; - check = dohd_url64_check("AAA="); /* Standard base64 padding */ + check = dohd_url64_check("AAA=", 4); /* Standard base64 padding */ TEST_ASSERT(check == 0, "Standard padding '=' is invalid for url64"); - check = dohd_url64_check("AAA+"); /* Standard base64 '+' */ + check = dohd_url64_check("AAA+", 4); /* Standard base64 '+' */ TEST_ASSERT(check == 0, "Standard '+' is invalid for url64"); /* Note: This implementation accepts '/' (maps to 63) for compatibility */ - check = dohd_url64_check("AAA!"); + check = dohd_url64_check("AAA!", 4); TEST_ASSERT(check == 0, "Invalid char '!' detected"); - check = dohd_url64_check("AAA "); + check = dohd_url64_check("AAA ", 4); TEST_ASSERT(check == 0, "Space char is invalid"); - check = dohd_url64_check("AAA\n"); + check = dohd_url64_check("AAA\n", 4); TEST_ASSERT(check == 0, "Newline char is invalid"); return 1; @@ -100,10 +96,10 @@ static int test_invalid_chars(void) { static int test_url64_chars(void) { int check; - check = dohd_url64_check("AAA-"); /* URL-safe minus */ + check = dohd_url64_check("AAA-", 4); /* URL-safe minus */ TEST_ASSERT(check == 4, "URL-safe '-' is valid"); - check = dohd_url64_check("AAA_"); /* URL-safe underscore */ + check = dohd_url64_check("AAA_", 4); /* URL-safe underscore */ TEST_ASSERT(check == 4, "URL-safe '_' is valid"); return 1; @@ -142,10 +138,10 @@ static int test_dns_query(void) { uint8_t out[64]; int len; - int check = dohd_url64_check(dns_b64); + int check = dohd_url64_check(dns_b64, strlen(dns_b64)); TEST_ASSERT(check == 39, "DNS query check returns correct length"); - len = dohd_url64_decode(dns_b64, out); + len = dohd_url64_decode(dns_b64, strlen(dns_b64), out, sizeof(out)); TEST_ASSERT(len == 29, "DNS query decode returns 29 bytes"); TEST_ASSERT(memcmp(out, expected, 29) == 0, "DNS query decode matches expected"); @@ -156,7 +152,7 @@ static int test_dns_query(void) { /* Test all alphanumeric characters */ static int test_all_valid_chars(void) { const char *all_valid = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"; - int check = dohd_url64_check(all_valid); + int check = dohd_url64_check(all_valid, strlen(all_valid)); TEST_ASSERT(check == 64, "All 64 valid characters accepted"); return 1; } @@ -173,10 +169,10 @@ static int test_long_input(void) { } long_input[2000] = '\0'; - int check = dohd_url64_check(long_input); + int check = dohd_url64_check(long_input, strlen(long_input)); TEST_ASSERT(check == 2000, "Long input check returns correct length"); - int len = dohd_url64_decode(long_input, out); + int len = dohd_url64_decode(long_input, strlen(long_input), out, sizeof(out)); TEST_ASSERT(len > 0, "Long input decode succeeds"); return 1; @@ -188,20 +184,27 @@ static int test_padding_scenarios(void) { int len; /* 2 chars = 1 byte output */ - len = dohd_url64_decode("QQ", out); + len = dohd_url64_decode("QQ", 2, out, sizeof(out)); TEST_ASSERT(len == 1, "2 char input -> 1 byte output"); /* 3 chars = 2 bytes output */ - len = dohd_url64_decode("QUE", out); + len = dohd_url64_decode("QUE", 3, out, sizeof(out)); TEST_ASSERT(len == 2, "3 char input -> 2 bytes output"); /* 4 chars = 3 bytes output */ - len = dohd_url64_decode("QUFB", out); + len = dohd_url64_decode("QUFB", 4, out, sizeof(out)); TEST_ASSERT(len == 3, "4 char input -> 3 bytes output"); return 1; } +static int test_decode_capacity_limit(void) { + uint8_t out[2]; + int len = dohd_url64_decode("QUFB", 4, out, sizeof(out)); + TEST_ASSERT(len == -1, "decode fails when destination capacity is too small"); + return 1; +} + int main(int argc, char **argv) { (void)argc; (void)argv; @@ -218,6 +221,7 @@ int main(int argc, char **argv) { test_all_valid_chars(); test_long_input(); test_padding_scenarios(); + test_decode_capacity_limit(); fprintf(stderr, "\n=== Results: %d/%d tests passed ===\n", tests_passed, tests_run);