diff --git a/.changeset/lemon-ties-try.md b/.changeset/lemon-ties-try.md new file mode 100644 index 000000000..3d458ebf0 --- /dev/null +++ b/.changeset/lemon-ties-try.md @@ -0,0 +1,6 @@ +--- +'@livekit/agents-plugin-bey': patch +'@livekit/agents': patch +--- + +Add Session Connection Options and Fix Blocking Speech from High-latency LLM Generation diff --git a/agents/src/inference/llm.ts b/agents/src/inference/llm.ts index b6601d495..64b952af0 100644 --- a/agents/src/inference/llm.ts +++ b/agents/src/inference/llm.ts @@ -149,7 +149,6 @@ export class LLM extends llm.LLM { this.client = new OpenAI({ baseURL: this.opts.baseURL, apiKey: '', // leave a temporary empty string to avoid OpenAI complain about missing key - timeout: 15000, }); } diff --git a/agents/src/inference/tts.ts b/agents/src/inference/tts.ts index 7bd9b5e24..c9432b097 100644 --- a/agents/src/inference/tts.ts +++ b/agents/src/inference/tts.ts @@ -297,14 +297,17 @@ export class SynthesizeStream extends BaseSynthesizeSt const createInputTask = async () => { for await (const data of this.input) { - if (this.abortController.signal.aborted) break; + if (this.abortController.signal.aborted || closing) break; if (data === SynthesizeStream.FLUSH_SENTINEL) { sendTokenizerStream.flush(); continue; } sendTokenizerStream.pushText(data); } - sendTokenizerStream.endInput(); + // Only call endInput if the stream hasn't been closed by cleanup + if (!closing) { + sendTokenizerStream.endInput(); + } }; const createSentenceStreamTask = async () => { diff --git a/agents/src/llm/llm.ts b/agents/src/llm/llm.ts index 0c478e4c3..746eddd7c 100644 --- a/agents/src/llm/llm.ts +++ b/agents/src/llm/llm.ts @@ -8,7 +8,7 @@ import { APIConnectionError, APIError } from '../_exceptions.js'; import { log } from '../log.js'; import type { LLMMetrics } from '../metrics/base.js'; import { recordException, traceTypes, tracer } from '../telemetry/index.js'; -import type { APIConnectOptions } from '../types.js'; +import { type 
APIConnectOptions, intervalForRetry } from '../types.js'; import { AsyncIterableQueue, delay, startSoon, toError } from '../utils.js'; import { type ChatContext, type ChatRole, type FunctionCall } from './chat_context.js'; import type { ToolChoice, ToolContext } from './tool_context.js'; @@ -158,7 +158,7 @@ export abstract class LLMStream implements AsyncIterableIterator { ); } catch (error) { if (error instanceof APIError) { - const retryInterval = this._connOptions._intervalForRetry(i); + const retryInterval = intervalForRetry(this._connOptions, i); if (this._connOptions.maxRetry === 0 || !error.retryable) { this.emitError({ error, recoverable: false }); diff --git a/agents/src/stt/stream_adapter.ts b/agents/src/stt/stream_adapter.ts index 5d390fe35..17f29a510 100644 --- a/agents/src/stt/stream_adapter.ts +++ b/agents/src/stt/stream_adapter.ts @@ -3,6 +3,7 @@ // SPDX-License-Identifier: Apache-2.0 import type { AudioFrame } from '@livekit/rtc-node'; import { log } from '../log.js'; +import type { APIConnectOptions } from '../types.js'; import type { VAD, VADStream } from '../vad.js'; import { VADEventType } from '../vad.js'; import type { SpeechEvent } from './stt.js'; @@ -28,8 +29,8 @@ export class StreamAdapter extends STT { return this.#stt.recognize(frame); } - stream(): StreamAdapterWrapper { - return new StreamAdapterWrapper(this.#stt, this.#vad); + stream(options?: { connOptions?: APIConnectOptions }): StreamAdapterWrapper { + return new StreamAdapterWrapper(this.#stt, this.#vad, options?.connOptions); } } @@ -38,8 +39,8 @@ export class StreamAdapterWrapper extends SpeechStream { #vadStream: VADStream; label: string; - constructor(stt: STT, vad: VAD) { - super(stt); + constructor(stt: STT, vad: VAD, connOptions?: APIConnectOptions) { + super(stt, undefined, connOptions); this.#stt = stt; this.#vadStream = vad.stream(); this.label = `stt.StreamAdapterWrapper<${this.#stt.label}>`; diff --git a/agents/src/stt/stt.ts b/agents/src/stt/stt.ts index 
0c0b332d8..039aa4f69 100644 --- a/agents/src/stt/stt.ts +++ b/agents/src/stt/stt.ts @@ -10,7 +10,7 @@ import { calculateAudioDurationSeconds } from '../audio.js'; import { log } from '../log.js'; import type { STTMetrics } from '../metrics/base.js'; import { DeferredReadableStream } from '../stream/deferred_stream.js'; -import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS } from '../types.js'; +import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS, intervalForRetry } from '../types.js'; import type { AudioBuffer } from '../utils.js'; import { AsyncIterableQueue, delay, startSoon, toError } from '../utils.js'; @@ -133,8 +133,10 @@ export abstract class STT extends (EventEmitter as new () => TypedEmitter { return; @@ -196,7 +198,7 @@ export abstract class SpeechStream implements AsyncIterableIterator return await this.run(); } catch (error) { if (error instanceof APIError) { - const retryInterval = this._connOptions._intervalForRetry(i); + const retryInterval = intervalForRetry(this._connOptions, i); if (this._connOptions.maxRetry === 0 || !error.retryable) { this.emitError({ error, recoverable: false }); diff --git a/agents/src/tts/stream_adapter.ts b/agents/src/tts/stream_adapter.ts index dfa24e436..10a1e59a2 100644 --- a/agents/src/tts/stream_adapter.ts +++ b/agents/src/tts/stream_adapter.ts @@ -2,6 +2,7 @@ // // SPDX-License-Identifier: Apache-2.0 import type { SentenceStream, SentenceTokenizer } from '../tokenize/index.js'; +import type { APIConnectOptions } from '../types.js'; import { Task } from '../utils.js'; import type { ChunkedStream } from './tts.js'; import { SynthesizeStream, TTS } from './tts.js'; @@ -27,8 +28,8 @@ export class StreamAdapter extends TTS { return this.#tts.synthesize(text); } - stream(): StreamAdapterWrapper { - return new StreamAdapterWrapper(this.#tts, this.#sentenceTokenizer); + stream(options?: { connOptions?: APIConnectOptions }): StreamAdapterWrapper { + return new StreamAdapterWrapper(this.#tts, 
this.#sentenceTokenizer, options?.connOptions); } } @@ -37,8 +38,8 @@ export class StreamAdapterWrapper extends SynthesizeStream { #sentenceStream: SentenceStream; label: string; - constructor(tts: TTS, sentenceTokenizer: SentenceTokenizer) { - super(tts); + constructor(tts: TTS, sentenceTokenizer: SentenceTokenizer, connOptions?: APIConnectOptions) { + super(tts, connOptions); this.#tts = tts; this.#sentenceStream = sentenceTokenizer.stream(); this.label = `tts.StreamAdapterWrapper<${this.#tts.label}>`; diff --git a/agents/src/tts/tts.ts b/agents/src/tts/tts.ts index db805110b..a4ced77ad 100644 --- a/agents/src/tts/tts.ts +++ b/agents/src/tts/tts.ts @@ -11,7 +11,7 @@ import { log } from '../log.js'; import type { TTSMetrics } from '../metrics/base.js'; import { DeferredReadableStream } from '../stream/deferred_stream.js'; import { recordException, traceTypes, tracer } from '../telemetry/index.js'; -import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS } from '../types.js'; +import { type APIConnectOptions, DEFAULT_API_CONNECT_OPTIONS, intervalForRetry } from '../types.js'; import { AsyncIterableQueue, delay, mergeFrames, startSoon, toError } from '../utils.js'; /** SynthesizedAudio is a packet of speech synthesis as returned by the TTS. */ @@ -94,8 +94,10 @@ export abstract class TTS extends (EventEmitter as new () => TypedEmitter { return; @@ -186,7 +188,7 @@ export abstract class SynthesizeStream ); } catch (error) { if (error instanceof APIError) { - const retryInterval = this._connOptions._intervalForRetry(i); + const retryInterval = intervalForRetry(this._connOptions, i); if (this._connOptions.maxRetry === 0 || !error.retryable) { this.emitError({ error, recoverable: false }); @@ -454,7 +456,7 @@ export abstract class ChunkedStream implements AsyncIterableIterator = {}) { - this.maxRetry = options.maxRetry ?? 3; - this.retryIntervalMs = options.retryIntervalMs ?? 2000; - this.timeoutMs = options.timeoutMs ?? 
10000; +/** + * Connection options for API calls, controlling retry and timeout behavior. + */ +export interface APIConnectOptions { + /** Maximum number of retries to connect to the API. Default: 3 */ + maxRetry: number; + /** Interval between retries to connect to the API in milliseconds. Default: 2000 */ + retryIntervalMs: number; + /** Timeout for connecting to the API in milliseconds. Default: 10000 */ + timeoutMs: number; +} - if (this.maxRetry < 0) { - throw new Error('maxRetry must be greater than or equal to 0'); - } - if (this.retryIntervalMs < 0) { - throw new Error('retryIntervalMs must be greater than or equal to 0'); - } - if (this.timeoutMs < 0) { - throw new Error('timeoutMs must be greater than or equal to 0'); - } - } +export const DEFAULT_API_CONNECT_OPTIONS: APIConnectOptions = { + maxRetry: 3, + retryIntervalMs: 2000, + timeoutMs: 10000, +}; - /** @internal */ - _intervalForRetry(numRetries: number): number { - /** - * Return the interval for the given number of retries. - * - * The first retry is immediate, and then uses specified retryIntervalMs - */ - if (numRetries === 0) { - return 0.1; - } - return this.retryIntervalMs; +/** + * Return the interval for the given number of retries. + * The first retry is immediate, and then uses specified retryIntervalMs. + * @internal + */ +export function intervalForRetry(connOptions: APIConnectOptions, numRetries: number): number { + if (numRetries === 0) { + return 0.1; } + return connOptions.retryIntervalMs; +} + +/** + * Connection options for the agent session, controlling retry and timeout behavior + * for STT, LLM, and TTS connections. + */ +export interface SessionConnectOptions { + /** Connection options for speech-to-text. */ + sttConnOptions?: Partial; + /** Connection options for the language model. */ + llmConnOptions?: Partial; + /** Connection options for text-to-speech. 
*/ + ttsConnOptions?: Partial; + /** Maximum number of consecutive unrecoverable errors from LLM or TTS before closing the session. Default: 3 */ + maxUnrecoverableErrors?: number; +} + +/** + * Resolved session connect options with all values populated. + * @internal + */ +export interface ResolvedSessionConnectOptions { + sttConnOptions: APIConnectOptions; + llmConnOptions: APIConnectOptions; + ttsConnOptions: APIConnectOptions; + maxUnrecoverableErrors: number; } -export const DEFAULT_API_CONNECT_OPTIONS = new APIConnectOptions(); +export const DEFAULT_SESSION_CONNECT_OPTIONS: ResolvedSessionConnectOptions = { + sttConnOptions: DEFAULT_API_CONNECT_OPTIONS, + llmConnOptions: DEFAULT_API_CONNECT_OPTIONS, + ttsConnOptions: DEFAULT_API_CONNECT_OPTIONS, + maxUnrecoverableErrors: 3, +}; diff --git a/agents/src/voice/agent.ts b/agents/src/voice/agent.ts index 6f529f038..10ee8a490 100644 --- a/agents/src/voice/agent.ts +++ b/agents/src/voice/agent.ts @@ -268,7 +268,8 @@ export class Agent { wrapped_stt = new STTStreamAdapter(wrapped_stt, agent.vad); } - const stream = wrapped_stt.stream(); + const connOptions = activity.agentSession.connOptions.sttConnOptions; + const stream = wrapped_stt.stream({ connOptions }); stream.updateInputStream(audio); return new ReadableStream({ @@ -304,11 +305,13 @@ export class Agent { // TODO(brian): make parallelToolCalls configurable const { toolChoice } = modelSettings; + const connOptions = activity.agentSession.connOptions.llmConnOptions; const stream = activity.llm.chat({ chatCtx, toolCtx, toolChoice, + connOptions, parallelToolCalls: true, }); return new ReadableStream({ @@ -340,7 +343,8 @@ export class Agent { wrapped_tts = new TTSStreamAdapter(wrapped_tts, new BasicSentenceTokenizer()); } - const stream = wrapped_tts.stream(); + const connOptions = activity.agentSession.connOptions.ttsConnOptions; + const stream = wrapped_tts.stream({ connOptions }); stream.updateInputStream(text); return new ReadableStream({ diff --git 
a/agents/src/voice/agent_session.ts b/agents/src/voice/agent_session.ts index a0618be7a..6b662e351 100644 --- a/agents/src/voice/agent_session.ts +++ b/agents/src/voice/agent_session.ts @@ -25,6 +25,12 @@ import type { STT } from '../stt/index.js'; import type { STTError } from '../stt/stt.js'; import { traceTypes, tracer } from '../telemetry/index.js'; import type { TTS, TTSError } from '../tts/tts.js'; +import { + DEFAULT_API_CONNECT_OPTIONS, + DEFAULT_SESSION_CONNECT_OPTIONS, + type ResolvedSessionConnectOptions, + type SessionConnectOptions, +} from '../types.js'; import type { VAD } from '../vad.js'; import type { Agent } from './agent.js'; import { AgentActivity } from './agent_activity.js'; @@ -100,6 +106,7 @@ export type AgentSessionOptions = { tts?: TTS | TTSModelString; userData?: UserData; voiceOptions?: Partial; + connOptions?: SessionConnectOptions; }; export class AgentSession< @@ -132,6 +139,13 @@ export class AgentSession< private closingTask: Promise | null = null; private userAwayTimer: NodeJS.Timeout | null = null; + // Connection options for STT, LLM, and TTS + private _connOptions: ResolvedSessionConnectOptions; + + // Unrecoverable error counts, reset when the agent starts speaking + private llmErrorCounts = 0; + private ttsErrorCounts = 0; + private sessionSpan?: Span; private userSpeakingSpan?: Span; private agentSpeakingSpan?: Span; @@ -159,8 +173,19 @@ export class AgentSession< turnDetection, userData, voiceOptions = defaultVoiceOptions, + connOptions, } = opts; + // Merge user-provided connOptions with defaults + this._connOptions = { + sttConnOptions: { ...DEFAULT_API_CONNECT_OPTIONS, ...connOptions?.sttConnOptions }, + llmConnOptions: { ...DEFAULT_API_CONNECT_OPTIONS, ...connOptions?.llmConnOptions }, + ttsConnOptions: { ...DEFAULT_API_CONNECT_OPTIONS, ...connOptions?.ttsConnOptions }, + maxUnrecoverableErrors: + connOptions?.maxUnrecoverableErrors ?? 
+ DEFAULT_SESSION_CONNECT_OPTIONS.maxUnrecoverableErrors, + }; + this.vad = vad; if (typeof stt === 'string') { @@ -225,6 +250,11 @@ export class AgentSession< return this._chatCtx; } + /** Connection options for STT, LLM, and TTS. */ + get connOptions(): ResolvedSessionConnectOptions { + return this._connOptions; + } + set userData(value: UserData) { this._userData = value; } @@ -514,6 +544,19 @@ export class AgentSession< return; } + // Track error counts per type to enforce the maxUnrecoverableErrors limit + if (error.type === 'llm_error') { + this.llmErrorCounts += 1; + if (this.llmErrorCounts <= this._connOptions.maxUnrecoverableErrors) { + return; + } + } else if (error.type === 'tts_error') { + this.ttsErrorCounts += 1; + if (this.ttsErrorCounts <= this._connOptions.maxUnrecoverableErrors) { + return; + } + } + this.logger.error(error, 'AgentSession is closing due to unrecoverable error'); this.closingTask = (async () => { @@ -541,7 +584,9 @@ export class AgentSession< } if (state === 'speaking') { - // TODO(brian): PR4 - Track error counts + // Reset error counts when agent starts speaking + this.llmErrorCounts = 0; + this.ttsErrorCounts = 0; if (this.agentSpeakingSpan === undefined) { this.agentSpeakingSpan = tracer.startSpan({ @@ -730,6 +775,8 @@ export class AgentSession< this.userState = 'listening'; this._agentState = 'initializing'; this.rootSpanContext = undefined; + this.llmErrorCounts = 0; + this.ttsErrorCounts = 0; this.logger.info({ reason, error }, 'AgentSession closed'); } diff --git a/agents/src/voice/generation.ts b/agents/src/voice/generation.ts index d065b28d0..82777ce12 100644 --- a/agents/src/voice/generation.ts +++ b/agents/src/voice/generation.ts @@ -24,7 +24,7 @@ import { isZodSchema, parseZodSchema } from '../llm/zod-utils.js'; import { log } from '../log.js'; import { IdentityTransform } from '../stream/identity_transform.js'; import { traceTypes, tracer } from '../telemetry/index.js'; -import { Future, Task, shortuuid, toError } 
from '../utils.js'; +import { Future, Task, shortuuid, toError, waitForAbort } from '../utils.js'; import { type Agent, type ModelSettings, asyncLocalStorage, isStopResponse } from './agent.js'; import type { AgentSession } from './agent_session.js'; import type { AudioOutput, LLMNode, TTSNode, TextOutput } from './io.js'; @@ -411,17 +411,19 @@ export function performLLMInference( return; } + const abortPromise = waitForAbort(signal); + // TODO(brian): add support for dynamic tools llmStreamReader = llmStream.getReader(); while (true) { - if (signal.aborted) { - break; - } - const { done, value: chunk } = await llmStreamReader.read(); - if (done) { - break; - } + if (signal.aborted) break; + + const result = await Promise.race([llmStreamReader.read(), abortPromise]); + if (result === undefined) break; + + const { done, value: chunk } = result; + if (done) break; if (typeof chunk === 'string') { data.generatedText += chunk; diff --git a/examples/src/basic_agent.ts b/examples/src/basic_agent.ts index a4cda7a3b..6228c056b 100644 --- a/examples/src/basic_agent.ts +++ b/examples/src/basic_agent.ts @@ -61,6 +61,14 @@ export default defineAgent({ // allow the LLM to generate a response while waiting for the end of turn preemptiveGeneration: true, }, + connOptions: { + // Example of overriding the default connection options for the LLM/TTS/STT + llmConnOptions: { + maxRetry: 1, + retryIntervalMs: 2000, + timeoutMs: 60000, + }, + }, }); const usageCollector = new metrics.UsageCollector(); diff --git a/plugins/bey/src/avatar.ts b/plugins/bey/src/avatar.ts index 3393741d3..ab1a94908 100644 --- a/plugins/bey/src/avatar.ts +++ b/plugins/bey/src/avatar.ts @@ -7,6 +7,7 @@ import { APIStatusError, DEFAULT_API_CONNECT_OPTIONS, getJobContext, + intervalForRetry, voice, } from '@livekit/agents'; import type { Room } from '@livekit/rtc-node'; @@ -232,7 +233,7 @@ export class AvatarSession { if (i < this.connOptions.maxRetry - 1) { await new Promise((resolve) => - setTimeout(resolve, 
this.connOptions._intervalForRetry(i)), + setTimeout(resolve, intervalForRetry(this.connOptions, i)), ); } }