Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import os
import weakref
from dataclasses import dataclass
from typing import Literal
from urllib.parse import urlencode

import aiohttp
Expand All @@ -45,7 +46,10 @@
class STTOptions:
sample_rate: int
buffer_size_seconds: float
encoding: str = "pcm_s16le"
encoding: Literal["pcm_s16le", "pcm_mulaw"] = "pcm_s16le"
speech_model: Literal["universal-streaming-english", "universal-streaming-multilingual"] = (
"universal-streaming-english"
)
end_of_turn_confidence_threshold: NotGivenOr[float] = NOT_GIVEN
min_end_of_turn_silence_when_confident: NotGivenOr[int] = NOT_GIVEN
max_turn_silence: NotGivenOr[int] = NOT_GIVEN
Expand All @@ -59,7 +63,10 @@ def __init__(
*,
api_key: NotGivenOr[str] = NOT_GIVEN,
sample_rate: int = 16000,
encoding: str = "pcm_s16le",
encoding: Literal["pcm_s16le", "pcm_mulaw"] = "pcm_s16le",
model: Literal[
"universal-streaming-english", "universal-streaming-multilingual"
] = "universal-streaming-english",
end_of_turn_confidence_threshold: NotGivenOr[float] = NOT_GIVEN,
min_end_of_turn_silence_when_confident: NotGivenOr[int] = NOT_GIVEN,
max_turn_silence: NotGivenOr[int] = NOT_GIVEN,
Expand All @@ -83,6 +90,7 @@ def __init__(
sample_rate=sample_rate,
buffer_size_seconds=buffer_size_seconds,
encoding=encoding,
speech_model=model,
end_of_turn_confidence_threshold=end_of_turn_confidence_threshold,
min_end_of_turn_silence_when_confident=min_end_of_turn_silence_when_confident,
max_turn_silence=max_turn_silence,
Expand Down Expand Up @@ -301,6 +309,7 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
live_config = {
"sample_rate": self._opts.sample_rate,
"encoding": self._opts.encoding,
"speech_model": self._opts.speech_model,
"format_turns": self._opts.format_turns if is_given(self._opts.format_turns) else None,
"end_of_turn_confidence_threshold": self._opts.end_of_turn_confidence_threshold
if is_given(self._opts.end_of_turn_confidence_threshold)
Expand Down Expand Up @@ -335,23 +344,37 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
def _process_stream_event(self, data: dict) -> None:
message_type = data.get("type")
if message_type == "Turn":
transcript = data.get("transcript")
words = data.get("words", [])
end_of_turn = data.get("end_of_turn")

if transcript and end_of_turn:
turn_is_formatted = data.get("turn_is_formatted", False)
if not self._opts.format_turns or (self._opts.format_turns and turn_is_formatted):
final_event = stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
# TODO: We can't know the language?
alternatives=[stt.SpeechData(language="en-US", text=transcript)],
)
else:
# skip emitting final transcript if format_turns is enabled but this
# turn isn't formatted
return
end_of_turn = data.get("end_of_turn", False)
turn_is_formatted = data.get("turn_is_formatted", False)
utterance = data.get("utterance", "")
transcript = data.get("transcript", "")

if words:
interim_text = " ".join(word.get("text", "") for word in words)
interim_event = stt.SpeechEvent(
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
alternatives=[stt.SpeechData(language="en", text=interim_text)],
)
self._event_ch.send_nowait(interim_event)

if utterance:
final_event = stt.SpeechEvent(
type=stt.SpeechEventType.PREFLIGHT_TRANSCRIPT,
alternatives=[stt.SpeechData(language="en", text=utterance)],
)
self._event_ch.send_nowait(final_event)

if end_of_turn and (
not (is_given(self._opts.format_turns) and self._opts.format_turns)
or turn_is_formatted
):
final_event = stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[stt.SpeechData(language="en", text=transcript)],
)
self._event_ch.send_nowait(final_event)

self._event_ch.send_nowait(stt.SpeechEvent(type=stt.SpeechEventType.END_OF_SPEECH))

if self._speech_duration > 0.0:
Expand All @@ -364,12 +387,3 @@ def _process_stream_event(self, data: dict) -> None:
)
self._event_ch.send_nowait(usage_event)
self._speech_duration = 0

else:
non_final_words = [word["text"] for word in words if not word["word_is_final"]]
interim = " ".join(non_final_words)
interim_event = stt.SpeechEvent(
type=stt.SpeechEventType.INTERIM_TRANSCRIPT,
alternatives=[stt.SpeechData(language="en-US", text=f"{transcript} {interim}")],
)
self._event_ch.send_nowait(interim_event)