Skip to content

Commit f8c65f8

Browse files
feat(ws_transport): add header callback hook
Part of espressif/esp-protocols#715
1 parent adcbdd7 commit f8c65f8

File tree

2 files changed

+119
-35
lines changed

2 files changed

+119
-35
lines changed

components/tcp_transport/include/esp_transport_ws.h

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ typedef enum ws_transport_opcodes {
2929
* from the API esp_transport_ws_get_read_opcode() */
3030
} ws_transport_opcodes_t;
3131

32+
typedef void (*ws_header_hook)(void * userp, const char * line, int line_len);
33+
3234
/**
3335
* WS transport configuration structure
3436
*/
@@ -37,6 +39,8 @@ typedef struct {
3739
const char *sub_protocol; /*!< WS subprotocol */
3840
const char *user_agent; /*!< WS user agent */
3941
const char *headers; /*!< WS additional headers */
42+
ws_header_hook header_hook; /*!< WS received header */
43+
void *header_userp; /*!< WS received header user-pointer */
4044
const char *auth; /*!< HTTP authorization header */
4145
char *response_headers; /*!< The buffer to copy the http response header */
4246
size_t response_headers_len; /*!< The length of the http response header */
@@ -99,6 +103,31 @@ esp_err_t esp_transport_ws_set_user_agent(esp_transport_handle_t t, const char *
99103
*/
100104
esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *headers);
101105

106+
/**
107+
* @brief Set websocket header callback
108+
*
109+
* @param t websocket transport handle
110+
* @param hook call function on header received. NULL to disable.
111+
*
112+
* @return
113+
* - ESP_OK on success
114+
* - One of the error codes
115+
*/
116+
esp_err_t esp_transport_ws_set_header_hook(esp_transport_handle_t t, ws_header_hook hook);
117+
118+
119+
/**
120+
* @brief Set websocket header callback user-pointer
121+
*
122+
* @param t websocket transport handle
123+
* @param userp caller-controlled argument to ws_header_hook
124+
*
125+
* @return
126+
* - ESP_OK on success
127+
* - One of the error codes
128+
*/
129+
esp_err_t esp_transport_ws_set_header_userp(esp_transport_handle_t t, void * userp);
130+
102131
/**
103132
* @brief Set websocket authorization headers
104133
*

components/tcp_transport/transport_ws.c

Lines changed: 90 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,8 @@ typedef struct {
6363
char *sub_protocol;
6464
char *user_agent;
6565
char *headers;
66+
ws_header_hook header_hook;
67+
void * header_userp;
6668
char *auth;
6769
char *buffer; /*!< Initial HTTP connection buffer, which may include data beyond the handshake headers, such as the next WebSocket packet*/
6870
size_t buffer_len; /*!< The buffer length */
@@ -144,29 +146,31 @@ static int esp_transport_read_internal(transport_ws_t *ws, char *buffer, int len
144146
return to_read;
145147
}
146148

147-
static char *trimwhitespace(char *str)
149+
static int esp_transport_read_blocking(transport_ws_t *ws, char *buffer, int len, int timeout_ms)
148150
{
149-
char *end;
150-
151-
// Trim leading space
152-
while (isspace((unsigned char)*str)) {
153-
str++;
154-
}
151+
int orig_len = len;
152+
int did_read = 0;
153+
154+
//Sometimes we get 1B/read, sometimes we need 8B?
155+
//Never seen this need >2 reads, better be safe though - need those headers.
156+
const unsigned MAX_ATTEMPT = 8;
157+
//TODO: change this(or the underlying transports) to retry only until timeout_ms expires.
158+
unsigned attempt = 0;
159+
for(attempt = 0; attempt < MAX_ATTEMPT && len > 0; ++attempt){
160+
int tmp = esp_transport_read_internal(ws, buffer, len, timeout_ms);
161+
if(tmp < 0){
162+
return tmp;
163+
}
155164

156-
if (*str == 0) {
157-
return str;
165+
buffer += tmp;
166+
len -= tmp;
167+
did_read += tmp;
158168
}
159169

160-
// Trim trailing space
161-
end = str + strlen(str) - 1;
162-
while (end > str && isspace((unsigned char)*end)) {
163-
end--;
170+
if(attempt > 1){
171+
ESP_LOGI(TAG, "tried %u attempts to read %i bytes. did read %i bytes", attempt, orig_len, did_read);
164172
}
165-
166-
// Write new null terminator
167-
*(end + 1) = '\0';
168-
169-
return str;
173+
return did_read;
170174
}
171175

172176
static int get_http_status_code(const char *buffer)
@@ -189,21 +193,6 @@ static int get_http_status_code(const char *buffer)
189193
return -1;
190194
}
191195

192-
static char *get_http_header(char *buffer, const char *key)
193-
{
194-
char *found = strcasestr(buffer, key);
195-
if (found) {
196-
found += strlen(key);
197-
char *found_end = strstr(found, "\r\n");
198-
if (found_end) {
199-
*found_end = '\0'; // terminal string
200-
201-
return trimwhitespace(found);
202-
}
203-
}
204-
return NULL;
205-
}
206-
207196
static int ws_connect(esp_transport_handle_t t, const char *host, int port, int timeout_ms)
208197
{
209198
transport_ws_t *ws = esp_transport_get_context_data(t);
@@ -336,7 +325,45 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
336325
return ws->http_status_code;
337326
}
338327

339-
char *server_key = get_http_header(ws->buffer, "Sec-WebSocket-Accept:");
328+
const char *server_key = NULL;
329+
int server_key_len = 0;
330+
const char * header_cursor = strnstr(ws->buffer, "\r\n", header_len);
331+
if (!header_cursor){
332+
ESP_LOGE(TAG, "HTTP Header locate failed");
333+
return -1;
334+
}
335+
header_cursor += strlen("\r\n");
336+
337+
while(header_cursor < delim_ptr){
338+
const char * end_of_line = strnstr(header_cursor, "\r\n", header_len - (header_cursor - ws->buffer));
339+
if(!end_of_line){
340+
ESP_LOGE(TAG, "HTTP Header walk failed");
341+
return -1;
342+
}
343+
else if(end_of_line == header_cursor){
344+
ESP_LOGD(TAG, "HTTP Header walk found end");
345+
break;
346+
}
347+
int line_len = end_of_line - header_cursor;
348+
ESP_LOGD(TAG, "HTTP Header walk line:%.*s", line_len, header_cursor);
349+
350+
// Find the Sec-WebSocket-Accept header
351+
const char * header_sec_websocket_accept = "Sec-WebSocket-Accept: ";
352+
size_t header_sec_websocket_accept_len = strlen(header_sec_websocket_accept);
353+
if (line_len >= header_sec_websocket_accept_len && !strncmp(header_cursor, header_sec_websocket_accept, header_sec_websocket_accept_len)) {
354+
ESP_LOGD(TAG, "found server-key");
355+
server_key = header_cursor + header_sec_websocket_accept_len;
356+
server_key_len = line_len - header_sec_websocket_accept_len;
357+
}
358+
else if (ws->header_hook) {
359+
ws->header_hook(ws->header_userp, header_cursor, line_len);
360+
}
361+
362+
// Adjust cursor to the start of the next line
363+
header_cursor += line_len;
364+
header_cursor += strlen("\r\n");
365+
}
366+
340367
if (server_key == NULL) {
341368
ESP_LOGE(TAG, "Sec-WebSocket-Accept not found");
342369
return -1;
@@ -357,7 +384,7 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
357384
esp_crypto_base64_encode(expected_server_key, sizeof(expected_server_key), &outlen, expected_server_sha1, sizeof(expected_server_sha1));
358385
expected_server_key[ (outlen < sizeof(expected_server_key)) ? outlen : (sizeof(expected_server_key) - 1) ] = 0;
359386
ESP_LOGD(TAG, "server key=%s, send_key=%s, expected_server_key=%s", (char *)server_key, (char *)client_key, expected_server_key);
360-
if (strcmp((char *)expected_server_key, (char *)server_key) != 0) {
387+
if (strncmp((char *)expected_server_key, (char *)server_key, server_key_len) != 0) {
361388
ESP_LOGE(TAG, "Invalid websocket key");
362389
return -1;
363390
}
@@ -862,6 +889,26 @@ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *hea
862889
return ESP_OK;
863890
}
864891

892+
esp_err_t esp_transport_ws_set_header_hook(esp_transport_handle_t t, ws_header_hook hook)
893+
{
894+
if (t == NULL) {
895+
return ESP_ERR_INVALID_ARG;
896+
}
897+
transport_ws_t *ws = esp_transport_get_context_data(t);
898+
ws->header_hook = hook;
899+
return ESP_OK;
900+
}
901+
902+
esp_err_t esp_transport_ws_set_header_userp(esp_transport_handle_t t, void * userp)
903+
{
904+
if (t == NULL) {
905+
return ESP_ERR_INVALID_ARG;
906+
}
907+
transport_ws_t *ws = esp_transport_get_context_data(t);
908+
ws->header_userp = userp;
909+
return ESP_OK;
910+
}
911+
865912
esp_err_t esp_transport_ws_set_auth(esp_transport_handle_t t, const char *auth)
866913
{
867914
if (t == NULL) {
@@ -927,6 +974,14 @@ esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transp
927974
err = esp_transport_ws_set_headers(t, config->headers);
928975
ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;)
929976
}
977+
if (config->header_hook) {
978+
err = esp_transport_ws_set_header_hook(t, config->header_hook);
979+
ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;)
980+
}
981+
if (config->header_userp) {
982+
err = esp_transport_ws_set_header_userp(t, config->header_userp);
983+
ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;)
984+
}
930985
if (config->auth) {
931986
err = esp_transport_ws_set_auth(t, config->auth);
932987
ESP_TRANSPORT_ERR_OK_CHECK(TAG, err, return err;)

0 commit comments

Comments
 (0)