Skip to content

Commit 1494a8b

Browse files
authored
feat: add preflight transcript via utterance (#3654)
1 parent dada245 commit 1494a8b

File tree

1 file changed

+40
-26
lines changed
  • livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai

1 file changed

+40
-26
lines changed

livekit-plugins/livekit-plugins-assemblyai/livekit/plugins/assemblyai/stt.py

Lines changed: 40 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import os
2222
import weakref
2323
from dataclasses import dataclass
24+
from typing import Literal
2425
from urllib.parse import urlencode
2526

2627
import aiohttp
@@ -45,7 +46,10 @@
4546
class STTOptions:
4647
sample_rate: int
4748
buffer_size_seconds: float
48-
encoding: str = "pcm_s16le"
49+
encoding: Literal["pcm_s16le", "pcm_mulaw"] = "pcm_s16le"
50+
speech_model: Literal["universal-streaming-english", "universal-streaming-multilingual"] = (
51+
"universal-streaming-english"
52+
)
4953
end_of_turn_confidence_threshold: NotGivenOr[float] = NOT_GIVEN
5054
min_end_of_turn_silence_when_confident: NotGivenOr[int] = NOT_GIVEN
5155
max_turn_silence: NotGivenOr[int] = NOT_GIVEN
@@ -59,7 +63,10 @@ def __init__(
5963
*,
6064
api_key: NotGivenOr[str] = NOT_GIVEN,
6165
sample_rate: int = 16000,
62-
encoding: str = "pcm_s16le",
66+
encoding: Literal["pcm_s16le", "pcm_mulaw"] = "pcm_s16le",
67+
model: Literal[
68+
"universal-streaming-english", "universal-streaming-multilingual"
69+
] = "universal-streaming-english",
6370
end_of_turn_confidence_threshold: NotGivenOr[float] = NOT_GIVEN,
6471
min_end_of_turn_silence_when_confident: NotGivenOr[int] = NOT_GIVEN,
6572
max_turn_silence: NotGivenOr[int] = NOT_GIVEN,
@@ -83,6 +90,7 @@ def __init__(
8390
sample_rate=sample_rate,
8491
buffer_size_seconds=buffer_size_seconds,
8592
encoding=encoding,
93+
speech_model=model,
8694
end_of_turn_confidence_threshold=end_of_turn_confidence_threshold,
8795
min_end_of_turn_silence_when_confident=min_end_of_turn_silence_when_confident,
8896
max_turn_silence=max_turn_silence,
@@ -301,6 +309,7 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
301309
live_config = {
302310
"sample_rate": self._opts.sample_rate,
303311
"encoding": self._opts.encoding,
312+
"speech_model": self._opts.speech_model,
304313
"format_turns": self._opts.format_turns if is_given(self._opts.format_turns) else None,
305314
"end_of_turn_confidence_threshold": self._opts.end_of_turn_confidence_threshold
306315
if is_given(self._opts.end_of_turn_confidence_threshold)
@@ -335,23 +344,37 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
335344
def _process_stream_event(self, data: dict) -> None:
336345
message_type = data.get("type")
337346
if message_type == "Turn":
338-
transcript = data.get("transcript")
339347
words = data.get("words", [])
340-
end_of_turn = data.get("end_of_turn")
341-
342-
if transcript and end_of_turn:
343-
turn_is_formatted = data.get("turn_is_formatted", False)
344-
if not self._opts.format_turns or (self._opts.format_turns and turn_is_formatted):
345-
final_event = stt.SpeechEvent(
346-
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
347-
# TODO: We can't know the language?
348-
alternatives=[stt.SpeechData(language="en-US", text=transcript)],
349-
)
350-
else:
351-
# skip emitting final transcript if format_turns is enabled but this
352-
# turn isn't formatted
353-
return
348+
end_of_turn = data.get("end_of_turn", False)
349+
turn_is_formatted = data.get("turn_is_formatted", False)
350+
utterance = data.get("utterance", "")
351+
transcript = data.get("transcript", "")
352+
353+
if words:
354+
interim_text = " ".join(word.get("text", "") for word in words)
355+
interim_event = stt.SpeechEvent(
356+
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
357+
alternatives=[stt.SpeechData(language="en", text=interim_text)],
358+
)
359+
self._event_ch.send_nowait(interim_event)
360+
361+
if utterance:
362+
final_event = stt.SpeechEvent(
363+
type=stt.SpeechEventType.PREFLIGHT_TRANSCRIPT,
364+
alternatives=[stt.SpeechData(language="en", text=utterance)],
365+
)
354366
self._event_ch.send_nowait(final_event)
367+
368+
if end_of_turn and (
369+
not (is_given(self._opts.format_turns) and self._opts.format_turns)
370+
or turn_is_formatted
371+
):
372+
final_event = stt.SpeechEvent(
373+
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
374+
alternatives=[stt.SpeechData(language="en", text=transcript)],
375+
)
376+
self._event_ch.send_nowait(final_event)
377+
355378
self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))
356379

357380
if self._speech_duration > 0.0:
@@ -364,12 +387,3 @@ def _process_stream_event(self, data: dict) -> None:
364387
)
365388
self._event_ch.send_nowait(usage_event)
366389
self._speech_duration = 0
367-
368-
else:
369-
non_final_words = [word["text"] for word in words if not word["word_is_final"]]
370-
interim = " ".join(non_final_words)
371-
interim_event = stt.SpeechEvent(
372-
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
373-
alternatives=[stt.SpeechData(language="en-US", text=f"{transcript} {interim}")],
374-
)
375-
self._event_ch.send_nowait(interim_event)

0 commit comments

Comments
 (0)