diff --git a/livekit-agents/livekit/agents/voice/agent_activity.py b/livekit-agents/livekit/agents/voice/agent_activity.py index 182eec7849..40cb436db8 100644 --- a/livekit-agents/livekit/agents/voice/agent_activity.py +++ b/livekit-agents/livekit/agents/voice/agent_activity.py @@ -1391,8 +1391,10 @@ async def _user_turn_completed_task( if preemptive := self._preemptive_generation: # make sure the on_user_turn_completed didn't change some request parameters # otherwise invalidate the preemptive generation + if ( - preemptive.info.new_transcript == user_message.text_content + (preemptive.info.new_transcript or "").lower() + == (user_message.text_content or "").lower() and preemptive.chat_ctx.is_equivalent(temp_mutable_chat_ctx) and preemptive.tools == self.tools and preemptive.tool_choice == self._tool_choice diff --git a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py index 239389ec8e..d5bc898941 100644 --- a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py +++ b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py @@ -21,7 +21,7 @@ import os import weakref from dataclasses import dataclass -from typing import Literal +from typing import Any, Literal from urllib.parse import urlencode import aiohttp @@ -147,6 +147,7 @@ def update_options( end_of_turn_confidence_threshold: NotGivenOr[float] = NOT_GIVEN, min_end_of_turn_silence_when_confident: NotGivenOr[int] = NOT_GIVEN, max_turn_silence: NotGivenOr[int] = NOT_GIVEN, + keyterms_prompt: NotGivenOr[list[str]] = NOT_GIVEN, ) -> None: if is_given(buffer_size_seconds): self._opts.buffer_size_seconds = buffer_size_seconds @@ -158,6 +159,8 @@ def update_options( ) if is_given(max_turn_silence): self._opts.max_turn_silence = max_turn_silence + if is_given(keyterms_prompt): + self._opts.keyterms_prompt = keyterms_prompt for stream in self._streams: stream.update_options( @@ -165,6 +168,7 @@ def update_options( end_of_turn_confidence_threshold=end_of_turn_confidence_threshold, min_end_of_turn_silence_when_confident=min_end_of_turn_silence_when_confident, max_turn_silence=max_turn_silence, + keyterms_prompt=keyterms_prompt, ) @@ -188,6 +192,8 @@ def __init__( self._session = http_session self._speech_duration: float = 0 self._reconnect_event = asyncio.Event() + self._ws: aiohttp.ClientWebSocketResponse | None = None + self._current_language_code: str = "en" # Track language code from utterances def update_options( self, @@ -196,19 +202,45 @@ def update_options( end_of_turn_confidence_threshold: NotGivenOr[float] = NOT_GIVEN, min_end_of_turn_silence_when_confident: NotGivenOr[int] = NOT_GIVEN, max_turn_silence: NotGivenOr[int] = NOT_GIVEN, + keyterms_prompt: NotGivenOr[list[str]] = NOT_GIVEN, ) -> None: + # Build UpdateConfiguration message for dynamic updates + update_config: dict[str, Any] = {"type": "UpdateConfiguration"} + needs_update = False + if is_given(buffer_size_seconds): self._opts.buffer_size_seconds = buffer_size_seconds if is_given(end_of_turn_confidence_threshold): self._opts.end_of_turn_confidence_threshold = end_of_turn_confidence_threshold + update_config["end_of_turn_confidence_threshold"] = end_of_turn_confidence_threshold + needs_update = True if is_given(min_end_of_turn_silence_when_confident): self._opts.min_end_of_turn_silence_when_confident = ( min_end_of_turn_silence_when_confident ) + update_config["min_end_of_turn_silence_when_confident"] = ( + min_end_of_turn_silence_when_confident + ) + needs_update = True if is_given(max_turn_silence): self._opts.max_turn_silence = max_turn_silence - - self._reconnect_event.set() + update_config["max_turn_silence"] = max_turn_silence + needs_update = True + if is_given(keyterms_prompt): + self._opts.keyterms_prompt = keyterms_prompt + update_config["keyterms_prompt"] = keyterms_prompt + needs_update = True + + # Send UpdateConfiguration message to active websocket if available + if needs_update and self._ws is not None and not self._ws.closed: + update_msg = json.dumps(update_config) + asyncio.create_task(self._ws.send_str(update_msg)) + logger.debug(f"Sent UpdateConfiguration: {update_msg}") + return # Don't trigger reconnection for dynamic updates + + # Only trigger reconnection if buffer_size_seconds changed (requires reconnect) + if is_given(buffer_size_seconds): + self._reconnect_event.set() async def _run(self) -> None: """ @@ -280,6 +312,7 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None: while True: try: ws = await self._connect_ws() + self._ws = ws # Store reference for dynamic updates tasks = [ asyncio.create_task(send_task(ws)), asyncio.create_task(recv_task(ws)), @@ -304,11 +337,13 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None: finally: if ws is not None: await ws.close() + self._ws = None # Clear reference async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse: live_config = { "sample_rate": self._opts.sample_rate, "encoding": self._opts.encoding, + "language_detection": True, "speech_model": self._opts.speech_model, "format_turns": self._opts.format_turns if is_given(self._opts.format_turns) else None, "end_of_turn_confidence_threshold": self._opts.end_of_turn_confidence_threshold @@ -349,19 +384,32 @@ def _process_stream_event(self, data: dict) -> None: turn_is_formatted = data.get("turn_is_formatted", False) utterance = data.get("utterance", "") transcript = data.get("transcript", "") + confidence = words[-1].get("confidence", 0.0) if words else 0.0 + + # language_code is only returned with utterances, so track it for final transcript + if "language_code" in data: + self._current_language_code = data["language_code"] if words: interim_text = " ".join(word.get("text", "") for word in words) interim_event = stt.SpeechEvent( type=stt.SpeechEventType.INTERIM_TRANSCRIPT, - alternatives=[stt.SpeechData(language="en", text=interim_text)], + alternatives=[ + stt.SpeechData(language="en", text=interim_text, confidence=confidence) + ], ) self._event_ch.send_nowait(interim_event) if utterance: final_event = stt.SpeechEvent( type=stt.SpeechEventType.PREFLIGHT_TRANSCRIPT, - alternatives=[stt.SpeechData(language="en", text=utterance)], + alternatives=[ + stt.SpeechData( + language=self._current_language_code, + text=utterance, + confidence=confidence, + ) + ], ) self._event_ch.send_nowait(final_event) @@ -371,7 +419,13 @@ def _process_stream_event(self, data: dict) -> None: ): final_event = stt.SpeechEvent( type=stt.SpeechEventType.FINAL_TRANSCRIPT, - alternatives=[stt.SpeechData(language="en", text=transcript)], + alternatives=[ + stt.SpeechData( + language=self._current_language_code, + text=transcript, + confidence=confidence, + ) + ], ) self._event_ch.send_nowait(final_event)