@@ -51,6 +51,7 @@ import {
51
51
WebSocketHandlerDefinition ,
52
52
WsHandlerDefinitionLookup ,
53
53
} from './websocket-handler-definitions' ;
54
+ import { resetOrDestroy } from '../../util/socket-util' ;
54
55
55
56
export interface WebSocketHandler extends WebSocketHandlerDefinition {
56
57
handle (
@@ -157,19 +158,38 @@ function pipeWebSocket(inSocket: WebSocket, outSocket: WebSocket) {
157
158
} ) ;
158
159
}
159
160
160
- async function mirrorRejection ( socket : net . Socket , rejectionResponse : http . IncomingMessage ) {
161
- if ( socket . writable ) {
162
- const { statusCode, statusMessage, rawHeaders } = rejectionResponse ;
163
-
164
- socket . write (
165
- rawResponse ( statusCode || 500 , statusMessage || 'Unknown error' , pairFlatRawHeaders ( rawHeaders ) )
166
- ) ;
167
-
168
- const body = await streamToBuffer ( rejectionResponse ) ;
169
- if ( socket . writable ) socket . write ( body ) ;
170
- }
161
+ function mirrorRejection (
162
+ downstreamSocket : net . Socket ,
163
+ upstreamRejectionResponse : http . IncomingMessage ,
164
+ simulateConnectionErrors : boolean
165
+ ) {
166
+ return new Promise < void > ( ( resolve ) => {
167
+ if ( downstreamSocket . writable ) {
168
+ const { statusCode, statusMessage, rawHeaders } = upstreamRejectionResponse ;
169
+
170
+ downstreamSocket . write (
171
+ rawResponse ( statusCode || 500 , statusMessage || 'Unknown error' , pairFlatRawHeaders ( rawHeaders ) )
172
+ ) ;
173
+
174
+ upstreamRejectionResponse . pipe ( downstreamSocket ) ;
175
+ upstreamRejectionResponse . on ( 'end' , resolve ) ;
176
+ upstreamRejectionResponse . on ( 'error' , ( error ) => {
177
+ console . warn ( 'Error receiving WebSocket upstream rejection response:' , error ) ;
178
+ if ( simulateConnectionErrors ) {
179
+ resetOrDestroy ( downstreamSocket ) ;
180
+ } else {
181
+ downstreamSocket . destroy ( ) ;
182
+ }
183
+ resolve ( ) ;
184
+ } ) ;
171
185
172
- socket . destroy ( ) ;
186
+ // The socket is being optimistically written to and then killed - we don't care
187
+ // about any more errors occuring here.
188
+ downstreamSocket . on ( 'error' , ( ) => {
189
+ resolve ( ) ;
190
+ } ) ;
191
+ }
192
+ } ) . catch ( ( ) => { } ) ;
173
193
}
174
194
175
195
const rawResponse = (
@@ -402,25 +422,29 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
402
422
console . log ( `Unexpected websocket response from ${ wsUrl } : ${ res . statusCode } ` ) ;
403
423
404
424
// Clean up the downstream connection
405
- mirrorRejection ( incomingSocket , res ) ;
406
-
407
- // Clean up the upstream connection (WS would do this automatically, but doesn't if you listen to this event)
408
- // See https://github.com/websockets/ws/blob/45e17acea791d865df6b255a55182e9c42e5877a/lib/websocket.js#L1050
409
- // We don't match that perfectly, but this should be effectively equivalent:
410
- req . destroy ( ) ;
411
- if ( req . socket && ! req . socket . destroyed ) {
412
- res . socket . destroy ( ) ;
413
- }
414
- unexpectedResponse = true ; // So that we ignore this in the error handler
415
- upstreamWebSocket . terminate ( ) ;
425
+ mirrorRejection ( incomingSocket , res , this . simulateConnectionErrors ) . then ( ( ) => {
426
+ // Clean up the upstream connection (WS would do this automatically, but doesn't if you listen to this event)
427
+ // See https://github.com/websockets/ws/blob/45e17acea791d865df6b255a55182e9c42e5877a/lib/websocket.js#L1050
428
+ // We don't match that perfectly, but this should be effectively equivalent:
429
+ req . destroy ( ) ;
430
+ if ( res . socket ?. destroyed === false ) {
431
+ res . socket . destroy ( ) ;
432
+ }
433
+ unexpectedResponse = true ; // So that we ignore this in the error handler
434
+ upstreamWebSocket . terminate ( ) ;
435
+ } ) ;
416
436
} ) ;
417
437
418
438
// If there's some other error, we just kill the socket:
419
439
upstreamWebSocket . on ( 'error' , ( e ) => {
420
440
if ( unexpectedResponse ) return ; // Handled separately above
421
441
422
442
console . warn ( e ) ;
423
- incomingSocket . end ( ) ;
443
+ if ( this . simulateConnectionErrors ) {
444
+ resetOrDestroy ( incomingSocket ) ;
445
+ } else {
446
+ incomingSocket . end ( ) ;
447
+ }
424
448
} ) ;
425
449
426
450
incomingSocket . on ( 'error' , ( ) => upstreamWebSocket . close ( 1011 ) ) ; // Internal error
@@ -438,6 +462,7 @@ export class PassThroughWebSocketHandler extends PassThroughWebSocketHandlerDefi
438
462
return _ . create ( this . prototype , {
439
463
...data ,
440
464
proxyConfig : deserializeProxyConfig ( data . proxyConfig , channel , ruleParams ) ,
465
+ simulateConnectionErrors : data . simulateConnectionErrors ?? false ,
441
466
extraCACertificates : data . extraCACertificates || [ ] ,
442
467
ignoreHostHttpsErrors : data . ignoreHostCertificateErrors ,
443
468
clientCertificateHostMap : _ . mapValues ( data . clientCertificateHostMap ,
0 commit comments