Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
236 changes: 236 additions & 0 deletions __tests__/cancellation.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -790,6 +790,242 @@ describe.each(testMatrix())(
},
);

describe.each(testMatrix())('handler explicit uncaught error cancellation ($transport.name transport, $codec.name codec)',
async ({ transport, codec }) => {
const opts = { codec: codec.codec };

const { addPostTestCleanup, postTestCleanup } = createPostTestCleanups();
let getClientTransport: TestSetupHelpers['getClientTransport'];
let getServerTransport: TestSetupHelpers['getServerTransport'];
beforeEach(async () => {
const setup = await transport.setup({ client: opts, server: opts });
getClientTransport = setup.getClientTransport;
getServerTransport = setup.getServerTransport;

return async () => {
await postTestCleanup();
await setup.cleanup();
};
});

describe('e2e', () => {
test('rpc', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
addPostTestCleanup(async () => {
await cleanupTransports([clientTransport, serverTransport]);
});

const handler = makeMockHandler('rpc');
const services = {
service: ServiceSchema.define({
rpc: Procedure.rpc({
requestInit: Type.Object({}),
responseData: Type.Object({}),
handler,
}),
}),
};

const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

const resP = client.service.rpc.rpc({});

await waitFor(() => {
expect(handler).toHaveBeenCalledTimes(1);
});

const [{ ctx }] = handler.mock.calls[0];
const onRequestFinished = vi.fn();
ctx.signal.addEventListener('abort', onRequestFinished);

const err = ctx.uncaught(new Error('test'));

expect(err).toEqual(
Err({
code: UNCAUGHT_ERROR_CODE,
message: 'test',
}),
);

await waitFor(() => {
expect(onRequestFinished).toHaveBeenCalled();
});
await expect(resP).resolves.toEqual(err);

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});

test('stream', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
addPostTestCleanup(async () => {
await cleanupTransports([clientTransport, serverTransport]);
});

const handler = makeMockHandler('stream');
const services = {
service: ServiceSchema.define({
stream: Procedure.stream({
requestInit: Type.Object({}),
requestData: Type.Object({}),
responseData: Type.Object({}),
handler,
}),
}),
};

const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

const { reqWritable, resReadable } = client.service.stream.stream({});

await waitFor(() => {
expect(handler).toHaveBeenCalledTimes(1);
});

const [{ ctx, reqReadable, resWritable }] = handler.mock.calls[0];

const err = ctx.uncaught(new Error('test'));

expect(err).toEqual(
Err({
code: UNCAUGHT_ERROR_CODE,
message: 'test',
}),
);

expect(await reqReadable.collect()).toEqual([err]);
expect(resWritable.isWritable()).toEqual(false);

expect(await resReadable.collect()).toEqual([err]);
expect(reqWritable.isWritable()).toEqual(false);

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});

test('upload', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
addPostTestCleanup(async () => {
await cleanupTransports([clientTransport, serverTransport]);
});

const handler = makeMockHandler('upload');
const services = {
service: ServiceSchema.define({
upload: Procedure.upload({
requestInit: Type.Object({}),
requestData: Type.Object({}),
responseData: Type.Object({}),
handler,
}),
}),
};

const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

const { reqWritable, finalize } = client.service.upload.upload({});

await waitFor(() => {
expect(handler).toHaveBeenCalledTimes(1);
});

const [{ ctx, reqReadable }] = handler.mock.calls[0];

const err = ctx.uncaught(new Error('test'));

expect(err).toEqual(
Err({
code: UNCAUGHT_ERROR_CODE,
message: 'test',
}),
);

expect(await finalize()).toEqual(err);
expect(reqWritable.isWritable()).toEqual(false);
expect(await reqReadable.collect()).toEqual([err]);

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});

test('subscribe', async () => {
const clientTransport = getClientTransport('client');
const serverTransport = getServerTransport();
addPostTestCleanup(async () => {
await cleanupTransports([clientTransport, serverTransport]);
});

const handler = makeMockHandler('subscription');
const services = {
service: ServiceSchema.define({
subscribe: Procedure.subscription({
requestInit: Type.Object({}),
responseData: Type.Object({}),
handler,
}),
}),
};

const server = createServer(serverTransport, services);
const client = createClient<typeof services>(
clientTransport,
serverTransport.clientId,
);

const { resReadable } = client.service.subscribe.subscribe({});

await waitFor(() => {
expect(handler).toHaveBeenCalledTimes(1);
});

const [{ ctx, resWritable }] = handler.mock.calls[0];

const err = ctx.uncaught(new Error('test'));

expect(err).toEqual(
Err({
code: UNCAUGHT_ERROR_CODE,
message: 'test',
}),
);

expect(await resReadable.collect()).toEqual([err]);
expect(resWritable.isWritable()).toEqual(false);

await testFinishesCleanly({
clientTransports: [clientTransport],
serverTransport,
server,
});
});
});
},
);

const createRejectable = () => {
let reject: (reason: Error) => void;
const promise = new Promise<void>((_res, rej) => {
Expand Down
11 changes: 10 additions & 1 deletion router/context.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ import { Span } from '@opentelemetry/api';
import { TransportClientId } from '../transport/message';
import { SessionId } from '../transport/sessionStateMachine/common';
import { ErrResult } from './result';
import { CancelErrorSchema } from './errors';
import { CancelErrorSchema, UncaughtErrorSchema } from './errors';
import { Static } from '@sinclair/typebox';

/**
Expand Down Expand Up @@ -40,6 +40,15 @@ export type ProcedureHandlerContext<State, Context, ParsedMetadata> =
* the river documentation to understand the difference between the two concepts.
*/
cancel: (message?: string) => ErrResult<Static<typeof CancelErrorSchema>>;
/**
* This emits an uncaught error in the same way that throwing an error in a handler
* would. You should minimize the amount of work you do after calling this function
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we can create an eslint rule for this

* as this will start a cleanup of the entire procedure call.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe clarify this ends the readable/writables

*
* You'll typically want to use this for streaming procedures, as in e.g. an RPC
* you can just throw instead.
Comment on lines +48 to +49
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
* You'll typically want to use this for streaming procedures, as in e.g. an RPC
* you can just throw instead.
* You'll typically want to use this for server-sent stream procedures like subscriptions or streams which may handle things inside closures where throwing will not be caught by River's procedure uncaught handler.

*/
uncaught: (err?: unknown) => ErrResult<Static<typeof UncaughtErrorSchema>>;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what does it mean to throw an uncaught with no err?

/**
* This signal is a standard [AbortSignal](https://developer.mozilla.org/en-US/docs/Web/API/AbortSignal)
* triggered when the procedure invocation is done. This signal tracks the invocation/request finishing
Expand Down
15 changes: 11 additions & 4 deletions router/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,16 @@ export function castTypeboxValueErrors(
return result;
}

/**
* A schema for unexpected errors in handlers
*/
export const UncaughtErrorSchema = Type.Object({
code: Type.Literal(UNCAUGHT_ERROR_CODE),
message: Type.String(),
});

export const UncaughtResultSchema = ErrResultSchema(UncaughtErrorSchema);

/**
* A schema for cancel payloads sent from the client
*/
Expand All @@ -88,10 +98,7 @@ export const CancelResultSchema = ErrResultSchema(CancelErrorSchema);
* on the client).
*/
export const ReaderErrorSchema = Type.Union([
Type.Object({
code: Type.Literal(UNCAUGHT_ERROR_CODE),
message: Type.String(),
}),
UncaughtErrorSchema,
Type.Object({
code: Type.Literal(UNEXPECTED_DISCONNECT_CODE),
message: Type.String(),
Expand Down
14 changes: 11 additions & 3 deletions router/server.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import {
ValidationErrors,
castTypeboxValueErrors,
CancelResultSchema,
UncaughtResultSchema,
} from './errors';
import {
AnyService,
Expand Down Expand Up @@ -550,7 +551,7 @@ class RiverServer<
},
});

const onHandlerError = (err: unknown, span: Span) => {
const onHandlerError = (err: unknown, span: Span): Static<typeof UncaughtResultSchema> => {
const errorMsg = coerceErrorString(err);

span.recordException(err instanceof Error ? err : new Error(errorMsg));
Expand All @@ -571,10 +572,14 @@ class RiverServer<
},
);

onServerCancel({
const res = Err({
code: UNCAUGHT_ERROR_CODE,
message: errorMsg,
});

onServerCancel(res.payload);

return res;
};

// if the init message has a close flag then we know this stream
Expand Down Expand Up @@ -603,6 +608,9 @@ class RiverServer<

return Err(errRes);
},
uncaught: (err?: unknown) => {
return onHandlerError(err, span);
},
signal: finishedController.signal,
};

Expand Down Expand Up @@ -1039,7 +1047,7 @@ function getStreamCloseBackwardsCompat(protocolVersion: ProtocolVersion) {

export interface MiddlewareContext
extends Readonly<
Omit<ProcedureHandlerContext<unknown, unknown, unknown>, 'cancel'>
Omit<ProcedureHandlerContext<unknown, unknown, unknown>, 'cancel' | 'uncaught'>
> {
readonly streamId: StreamId;
readonly procedureName: string;
Expand Down
Loading