diff --git a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py
index 9ead487f26..239389ec8e 100644
--- a/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py
+++ b/livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py
@@ -21,6 +21,7 @@
 import os
 import weakref
 from dataclasses import dataclass
+from typing import Literal
 from urllib.parse import urlencode
 
 import aiohttp
@@ -45,7 +46,10 @@
 class STTOptions:
     sample_rate: int
     buffer_size_seconds: float
-    encoding: str = "pcm_s16le"
+    encoding: Literal["pcm_s16le", "pcm_mulaw"] = "pcm_s16le"
+    speech_model: Literal["universal-streaming-english", "universal-streaming-multilingual"] = (
+        "universal-streaming-english"
+    )
     end_of_turn_confidence_threshold: NotGivenOr[float] = NOT_GIVEN
     min_end_of_turn_silence_when_confident: NotGivenOr[int] = NOT_GIVEN
     max_turn_silence: NotGivenOr[int] = NOT_GIVEN
@@ -59,7 +63,10 @@ def __init__(
         *,
         api_key: NotGivenOr[str] = NOT_GIVEN,
         sample_rate: int = 16000,
-        encoding: str = "pcm_s16le",
+        encoding: Literal["pcm_s16le", "pcm_mulaw"] = "pcm_s16le",
+        model: Literal[
+            "universal-streaming-english", "universal-streaming-multilingual"
+        ] = "universal-streaming-english",
         end_of_turn_confidence_threshold: NotGivenOr[float] = NOT_GIVEN,
         min_end_of_turn_silence_when_confident: NotGivenOr[int] = NOT_GIVEN,
         max_turn_silence: NotGivenOr[int] = NOT_GIVEN,
@@ -83,6 +90,7 @@ def __init__(
             sample_rate=sample_rate,
             buffer_size_seconds=buffer_size_seconds,
             encoding=encoding,
+            speech_model=model,
             end_of_turn_confidence_threshold=end_of_turn_confidence_threshold,
             min_end_of_turn_silence_when_confident=min_end_of_turn_silence_when_confident,
             max_turn_silence=max_turn_silence,
@@ -301,6 +309,7 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
         live_config = {
             "sample_rate": self._opts.sample_rate,
             "encoding": self._opts.encoding,
+            "speech_model": self._opts.speech_model,
             "format_turns": self._opts.format_turns if is_given(self._opts.format_turns) else None,
             "end_of_turn_confidence_threshold": self._opts.end_of_turn_confidence_threshold
             if is_given(self._opts.end_of_turn_confidence_threshold)
@@ -335,23 +344,37 @@ def _process_stream_event(self, data: dict) -> None:
         message_type = data.get("type")
         if message_type == "Turn":
-            transcript = data.get("transcript")
             words = data.get("words", [])
-            end_of_turn = data.get("end_of_turn")
-
-            if transcript and end_of_turn:
-                turn_is_formatted = data.get("turn_is_formatted", False)
-                if not self._opts.format_turns or (self._opts.format_turns and turn_is_formatted):
-                    final_event = stt.SpeechEvent(
-                        type=stt.SpeechEventType.FINAL_TRANSCRIPT,
-                        # TODO: We can't know the language?
-                        alternatives=[stt.SpeechData(language="en-US", text=transcript)],
-                    )
-                else:
-                    # skip emitting final transcript if format_turns is enabled but this
-                    # turn isn't formatted
-                    return
+            end_of_turn = data.get("end_of_turn", False)
+            turn_is_formatted = data.get("turn_is_formatted", False)
+            utterance = data.get("utterance", "")
+            transcript = data.get("transcript", "")
+
+            if words:
+                interim_text = " ".join(word.get("text", "") for word in words)
+                interim_event = stt.SpeechEvent(
+                    type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
+                    alternatives=[stt.SpeechData(language="en", text=interim_text)],
+                )
+                self._event_ch.send_nowait(interim_event)
+
+            if utterance:
+                final_event = stt.SpeechEvent(
+                    type=stt.SpeechEventType.PREFLIGHT_TRANSCRIPT,
+                    alternatives=[stt.SpeechData(language="en", text=utterance)],
+                )
                 self._event_ch.send_nowait(final_event)
+
+            if end_of_turn and (
+                not (is_given(self._opts.format_turns) and self._opts.format_turns)
+                or turn_is_formatted
+            ):
+                final_event = stt.SpeechEvent(
+                    type=stt.SpeechEventType.FINAL_TRANSCRIPT,
+                    alternatives=[stt.SpeechData(language="en", text=transcript)],
+                )
+                self._event_ch.send_nowait(final_event)
+
 
                 self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
 
                 if self._speech_duration > 0.0:
@@ -364,12 +387,3 @@ def _process_stream_event(self, data: dict) -> None:
                     )
                     self._event_ch.send_nowait(usage_event)
                     self._speech_duration = 0
-
-            else:
-                non_final_words = [word["text"] for word in words if not word["word_is_final"]]
-                interim = " ".join(non_final_words)
-                interim_event = stt.SpeechEvent(
-                    type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
-                    alternatives=[stt.SpeechData(language="en-US", text=f"{transcript} {interim}")],
-                )
-                self._event_ch.send_nowait(interim_event)
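Usage sketch (not part of the patch): a minimal, hypothetical example of selecting the new speech model through the constructor parameter added above. It assumes the plugin is imported as livekit.plugins.assemblyai and that ASSEMBLYAI_API_KEY is set in the environment; parameter names mirror the __init__ signature in this diff, and encoding/sample_rate are shown only for illustration since they keep their existing defaults.

    from livekit.plugins import assemblyai

    # Select the multilingual Universal-Streaming model; the default remains
    # "universal-streaming-english". The value is forwarded to the realtime
    # websocket as the "speech_model" field of live_config.
    stt = assemblyai.STT(
        model="universal-streaming-multilingual",
        encoding="pcm_s16le",
        sample_rate=16000,
    )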