@@ -63,6 +63,8 @@ typedef struct {
6363 char * sub_protocol ;
6464 char * user_agent ;
6565 char * headers ;
66+ ws_header_hook header_hook ;
67+ void * header_userp ;
6668 char * auth ;
6769 char * buffer ; /*!< Initial HTTP connection buffer, which may include data beyond the handshake headers, such as the next WebSocket packet*/
6870 size_t buffer_len ; /*!< The buffer length */
@@ -144,29 +146,31 @@ static int esp_transport_read_internal(transport_ws_t *ws, char *buffer, int len
144146 return to_read ;
145147}
146148
147- static char * trimwhitespace ( char * str )
149+ static int esp_transport_read_blocking ( transport_ws_t * ws , char * buffer , int len , int timeout_ms )
148150{
149- char * end ;
150-
151- // Trim leading space
152- while (isspace ((unsigned char )* str )) {
153- str ++ ;
154- }
151+ int orig_len = len ;
152+ int did_read = 0 ;
153+
154+ //Sometimes we get 1B/read, sometimes we need 8B?
155+ //Never seen this need >2 reads, better be safe though - need those headers.
156+ const unsigned MAX_ATTEMPT = 8 ;
157+ //TODO: change this(or the underlying transports) to retry only until timeout_ms expires.
158+ unsigned attempt = 0 ;
159+ for (attempt = 0 ; attempt < MAX_ATTEMPT && len > 0 ; ++ attempt ){
160+ int tmp = esp_transport_read_internal (ws , buffer , len , timeout_ms );
161+ if (tmp < 0 ){
162+ return tmp ;
163+ }
155164
156- if (* str == 0 ) {
157- return str ;
165+ buffer += tmp ;
166+ len -= tmp ;
167+ did_read += tmp ;
158168 }
159169
160- // Trim trailing space
161- end = str + strlen (str ) - 1 ;
162- while (end > str && isspace ((unsigned char )* end )) {
163- end -- ;
170+ if (attempt > 1 ){
171+ ESP_LOGI (TAG , "tried %u attempts to read %i bytes. did read %i bytes" , attempt , orig_len , did_read );
164172 }
165-
166- // Write new null terminator
167- * (end + 1 ) = '\0' ;
168-
169- return str ;
173+ return did_read ;
170174}
171175
172176static int get_http_status_code (const char * buffer )
@@ -189,21 +193,6 @@ static int get_http_status_code(const char *buffer)
189193 return -1 ;
190194}
191195
192- static char * get_http_header (char * buffer , const char * key )
193- {
194- char * found = strcasestr (buffer , key );
195- if (found ) {
196- found += strlen (key );
197- char * found_end = strstr (found , "\r\n" );
198- if (found_end ) {
199- * found_end = '\0' ; // terminal string
200-
201- return trimwhitespace (found );
202- }
203- }
204- return NULL ;
205- }
206-
207196static int ws_connect (esp_transport_handle_t t , const char * host , int port , int timeout_ms )
208197{
209198 transport_ws_t * ws = esp_transport_get_context_data (t );
@@ -336,7 +325,45 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
336325 return ws -> http_status_code ;
337326 }
338327
339- char * server_key = get_http_header (ws -> buffer , "Sec-WebSocket-Accept:" );
328+ const char * server_key = NULL ;
329+ int server_key_len = 0 ;
330+ const char * header_cursor = strnstr (ws -> buffer , "\r\n" , header_len );
331+ if (!header_cursor ){
332+ ESP_LOGE (TAG , "HTTP Header locate failed" );
333+ return -1 ;
334+ }
335+ header_cursor += strlen ("\r\n" );
336+
337+ while (header_cursor < delim_ptr ){
338+ const char * end_of_line = strnstr (header_cursor , "\r\n" , header_len - (header_cursor - ws -> buffer ));
339+ if (!end_of_line ){
340+ ESP_LOGE (TAG , "HTTP Header walk failed" );
341+ return -1 ;
342+ }
343+ else if (end_of_line == header_cursor ){
344+ ESP_LOGD (TAG , "HTTP Header walk found end" );
345+ break ;
346+ }
347+ int line_len = end_of_line - header_cursor ;
348+ ESP_LOGD (TAG , "HTTP Header walk line:%.*s" , line_len , header_cursor );
349+
350+ // Find the Sec-WebSocket-Accept header
351+ const char * header_sec_websocket_accept = "Sec-WebSocket-Accept: " ;
352+ size_t header_sec_websocket_accept_len = strlen (header_sec_websocket_accept );
353+ if (line_len >= header_sec_websocket_accept_len && !strncmp (header_cursor , header_sec_websocket_accept , header_sec_websocket_accept_len )) {
354+ ESP_LOGD (TAG , "found server-key" );
355+ server_key = header_cursor + header_sec_websocket_accept_len ;
356+ server_key_len = line_len - header_sec_websocket_accept_len ;
357+ }
358+ else if (ws -> header_hook ) {
359+ ws -> header_hook (ws -> header_userp , header_cursor , line_len );
360+ }
361+
362+ // Adjust cursor to the start of the next line
363+ header_cursor += line_len ;
364+ header_cursor += strlen ("\r\n" );
365+ }
366+
340367 if (server_key == NULL ) {
341368 ESP_LOGE (TAG , "Sec-WebSocket-Accept not found" );
342369 return -1 ;
@@ -357,7 +384,7 @@ static int ws_connect(esp_transport_handle_t t, const char *host, int port, int
357384 esp_crypto_base64_encode (expected_server_key , sizeof (expected_server_key ), & outlen , expected_server_sha1 , sizeof (expected_server_sha1 ));
358385 expected_server_key [ (outlen < sizeof (expected_server_key )) ? outlen : (sizeof (expected_server_key ) - 1 ) ] = 0 ;
359386 ESP_LOGD (TAG , "server key=%s, send_key=%s, expected_server_key=%s" , (char * )server_key , (char * )client_key , expected_server_key );
360- if (strcmp ((char * )expected_server_key , (char * )server_key ) != 0 ) {
387+ if (strncmp ((char * )expected_server_key , (char * )server_key , server_key_len ) != 0 ) {
361388 ESP_LOGE (TAG , "Invalid websocket key" );
362389 return -1 ;
363390 }
@@ -862,6 +889,26 @@ esp_err_t esp_transport_ws_set_headers(esp_transport_handle_t t, const char *hea
862889 return ESP_OK ;
863890}
864891
892+ esp_err_t esp_transport_ws_set_header_hook (esp_transport_handle_t t , ws_header_hook hook )
893+ {
894+ if (t == NULL ) {
895+ return ESP_ERR_INVALID_ARG ;
896+ }
897+ transport_ws_t * ws = esp_transport_get_context_data (t );
898+ ws -> header_hook = hook ;
899+ return ESP_OK ;
900+ }
901+
902+ esp_err_t esp_transport_ws_set_header_userp (esp_transport_handle_t t , void * userp )
903+ {
904+ if (t == NULL ) {
905+ return ESP_ERR_INVALID_ARG ;
906+ }
907+ transport_ws_t * ws = esp_transport_get_context_data (t );
908+ ws -> header_userp = userp ;
909+ return ESP_OK ;
910+ }
911+
865912esp_err_t esp_transport_ws_set_auth (esp_transport_handle_t t , const char * auth )
866913{
867914 if (t == NULL ) {
@@ -927,6 +974,14 @@ esp_err_t esp_transport_ws_set_config(esp_transport_handle_t t, const esp_transp
927974 err = esp_transport_ws_set_headers (t , config -> headers );
928975 ESP_TRANSPORT_ERR_OK_CHECK (TAG , err , return err ;)
929976 }
977+ if (config -> header_hook ) {
978+ err = esp_transport_ws_set_header_hook (t , config -> header_hook );
979+ ESP_TRANSPORT_ERR_OK_CHECK (TAG , err , return err ;)
980+ }
981+ if (config -> header_userp ) {
982+ err = esp_transport_ws_set_header_userp (t , config -> header_userp );
983+ ESP_TRANSPORT_ERR_OK_CHECK (TAG , err , return err ;)
984+ }
930985 if (config -> auth ) {
931986 err = esp_transport_ws_set_auth (t , config -> auth );
932987 ESP_TRANSPORT_ERR_OK_CHECK (TAG , err , return err ;)
0 commit comments