@@ -63,6 +63,8 @@ typedef struct {
6363 char * sub_protocol ;
6464 char * user_agent ;
6565 char * headers ;
66+ ws_header_hook header_hook ;
67+ void * header_userp ;
6668 char * auth ;
6769 char * buffer ; /*!< Initial HTTP connection buffer, which may include data beyond the handshake headers, such as the next WebSocket packet*/
6870 size_t buffer_len ; /*!< The buffer length */
@@ -144,31 +146,6 @@ static int esp_transport_read_internal(transport_ws_t *ws, char *buffer, int len
144146 return to_read ;
145147}
146148
147- static char * trimwhitespace (char * str )
148- {
149- char * end ;
150-
151- // Trim leading space
152- while (isspace ((unsigned char )* str )) {
153- str ++ ;
154- }
155-
156- if (* str == 0 ) {
157- return str ;
158- }
159-
160- // Trim trailing space
161- end = str + strlen (str ) - 1 ;
162- while (end > str && isspace ((unsigned char )* end )) {
163- end -- ;
164- }
165-
166- // Write new null terminator
167- * (end + 1 ) = '\0' ;
168-
169- return str ;
170- }
171-
172149static int get_http_status_code (const char * buffer )
173150{
174151 const char http [] = "HTTP/" ;
@@ -189,21 +166,6 @@ static int get_http_status_code(const char *buffer)
189166 return -1 ;
190167}
191168
192- static char * get_http_header (char * buffer , const char * key )
193- {
194- char * found = strcasestr (buffer , key );
195- if (found ) {
196- found += strlen (key );
197- char * found_end = strstr (found , "\r\n" );
198- if (found_end ) {
199- * found_end = '\0' ; // terminal string
200-
201- return trimwhitespace (found );
202- }
203- }
204- return NULL ;
205- }
206-
207169static int ws_connect (esp_transport_handle_t t , const char * host , int port , int timeout_ms )
208170{
209171 transport_ws_t * ws = esp_transport_get_context_data (t );
@@ -330,17 +292,67 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
330292 if (ws -> http_status_code == -1 ) {
331293 ESP_LOGE (TAG , "HTTP upgrade failed" );
332294 return -1 ;
333- } else if (WS_HTTP_TEMPORARY_REDIRECT (ws -> http_status_code ) || WS_HTTP_PERMANENT_REDIRECT (ws -> http_status_code )) {
334- char * redir_host = get_http_header (ws -> buffer , "Location:" );
335- if (redir_host == NULL ) {
295+ }
296+
297+ const char * location = NULL ;
298+ int location_len = 0 ;
299+
300+ const char * server_key = NULL ;
301+ int server_key_len = 0 ;
302+ const char * header_cursor = strnstr (ws -> buffer , "\r\n" , header_len );
303+ if (!header_cursor ){
304+ ESP_LOGE (TAG , "HTTP Header locate failed" );
305+ return -1 ;
306+ }
307+ header_cursor += strlen ("\r\n" );
308+
309+ while (header_cursor < delim_ptr ){
310+ const char * end_of_line = strnstr (header_cursor , "\r\n" , header_len - (header_cursor - ws -> buffer ));
311+ if (!end_of_line ){
312+ ESP_LOGE (TAG , "HTTP Header walk failed" );
313+ return -1 ;
314+ }
315+ else if (end_of_line == header_cursor ){
316+ ESP_LOGD (TAG , "HTTP Header walk found end" );
317+ break ;
318+ }
319+ int line_len = end_of_line - header_cursor ;
320+ ESP_LOGD (TAG , "HTTP Header walk line:%.*s" , line_len , header_cursor );
321+
322+ // Check for Sec-WebSocket-Accept header
323+ const char * header_sec_websocket_accept = "Sec-WebSocket-Accept: " ;
324+ size_t header_sec_websocket_accept_len = strlen (header_sec_websocket_accept );
325+ if (line_len >= header_sec_websocket_accept_len && !strncasecmp (header_cursor , header_sec_websocket_accept , header_sec_websocket_accept_len )) {
326+ ESP_LOGD (TAG , "found server-key" );
327+ server_key = header_cursor + header_sec_websocket_accept_len ;
328+ server_key_len = line_len - header_sec_websocket_accept_len ;
329+ }
330+ else if (ws -> header_hook ) {
331+ ws -> header_hook (ws -> header_userp , header_cursor , line_len );
332+ }
333+
334+ // Check for Location: header
335+ const char * header_location = "Location: " ;
336+ size_t header_location_len = strlen (header_location );
337+ if (line_len >= header_location_len && !strncasecmp (header_cursor , header_location , header_location_len )) {
338+ location = header_cursor + header_location_len ;
339+ location_len = line_len - header_location_len ;
340+ }
341+
342+ // Adjust cursor to the start of the next line
343+ header_cursor += line_len ;
344+ header_cursor += strlen ("\r\n" );
345+ }
346+
347+ if (WS_HTTP_TEMPORARY_REDIRECT (ws -> http_status_code ) || WS_HTTP_PERMANENT_REDIRECT (ws -> http_status_code )) {
348+ if (location == NULL || location_len <= 0 ) {
336349 ESP_LOGE (TAG , "Location header not found" );
337350 return -1 ;
338351 }
339- ws -> redir_host = strdup ( redir_host );
352+ ws -> redir_host = strndup ( location , location_len );
340353 return ws -> http_status_code ;
341354 }
342355
343- char * server_key = get_http_header (ws -> buffer , "Sec-WebSocket-Accept:" );
344356 if (server_key == NULL ) {
345357 ESP_LOGE (TAG , "Sec-WebSocket-Accept not found" );
346358 return -1 ;
@@ -361,7 +373,7 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
361373 esp_crypto_base64_encode (expected_server_key , sizeof (expected_server_key ), & outlen , expected_server_sha1 , sizeof (expected_server_sha1 ));
362374 expected_server_key [ (outlen < sizeof (expected_server_key )) ? outlen : (sizeof (expected_server_key ) - 1 ) ] = 0 ;
363375 ESP_LOGD (TAG , "server key=%s, send_key=%s, expected_server_key=%s" , (char * )server_key , (char * )client_key , expected_server_key );
364- if (strcmp ((char * )expected_server_key , (char * )server_key ) != 0 ) {
376+ if (strncmp ((char * )expected_server_key , (char * )server_key , server_key_len ) != 0 ) {
365377 ESP_LOGE (TAG , "Invalid websocket key" );
366378 return -1 ;
367379 }
@@ -866,6 +878,26 @@ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *hea
866878 return ESP_OK ;
867879}
868880
881+ esp_err_t esp_transport_ws_set_header_hook (esp_transport_handle_t t , ws_header_hook hook )
882+ {
883+ if (t == NULL ) {
884+ return ESP_ERR_INVALID_ARG ;
885+ }
886+ transport_ws_t * ws = esp_transport_get_context_data (t );
887+ ws -> header_hook = hook ;
888+ return ESP_OK ;
889+ }
890+
891+ esp_err_t esp_transport_ws_set_header_userp (esp_transport_handle_t t , void * userp )
892+ {
893+ if (t == NULL ) {
894+ return ESP_ERR_INVALID_ARG ;
895+ }
896+ transport_ws_t * ws = esp_transport_get_context_data (t );
897+ ws -> header_userp = userp ;
898+ return ESP_OK ;
899+ }
900+
869901esp_err_t esp_transport_ws_set_auth (esp_transport_handle_t t , const char * auth )
870902{
871903 if (t == NULL ) {
@@ -931,6 +963,14 @@ esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transp
931963 err = esp_transport_ws_set_headers (t , config -> headers );
932964 ESP_TRANSPORT_ERR_OK_CHECK (TAG , err , return err ;)
933965 }
966+ if (config -> header_hook ) {
967+ err = esp_transport_ws_set_header_hook (t , config -> header_hook );
968+ ESP_TRANSPORT_ERR_OK_CHECK (TAG , err , return err ;)
969+ }
970+ if (config -> header_userp ) {
971+ err = esp_transport_ws_set_header_userp (t , config -> header_userp );
972+ ESP_TRANSPORT_ERR_OK_CHECK (TAG , err , return err ;)
973+ }
934974 if (config -> auth ) {
935975 err = esp_transport_ws_set_auth (t , config -> auth );
936976 ESP_TRANSPORT_ERR_OK_CHECK (TAG , err , return err ;)
0 commit comments