2121import os
2222import weakref
2323from dataclasses import dataclass
24+ from typing import Literal
2425from urllib .parse import urlencode
2526
2627import aiohttp
4546class STTOptions :
4647 sample_rate : int
4748 buffer_size_seconds : float
48- encoding : str = "pcm_s16le"
49+ encoding : Literal ["pcm_s16le" , "pcm_mulaw" ] = "pcm_s16le"
50+ speech_model : Literal ["universal-streaming-english" , "universal-streaming-multilingual" ] = (
51+ "universal-streaming-english"
52+ )
4953 end_of_turn_confidence_threshold : NotGivenOr [float ] = NOT_GIVEN
5054 min_end_of_turn_silence_when_confident : NotGivenOr [int ] = NOT_GIVEN
5155 max_turn_silence : NotGivenOr [int ] = NOT_GIVEN
@@ -59,7 +63,10 @@ def __init__(
5963 * ,
6064 api_key : NotGivenOr [str ] = NOT_GIVEN ,
6165 sample_rate : int = 16000 ,
62- encoding : str = "pcm_s16le" ,
66+ encoding : Literal ["pcm_s16le" , "pcm_mulaw" ] = "pcm_s16le" ,
67+ model : Literal [
68+ "universal-streaming-english" , "universal-streaming-multilingual"
69+ ] = "universal-streaming-english" ,
6370 end_of_turn_confidence_threshold : NotGivenOr [float ] = NOT_GIVEN ,
6471 min_end_of_turn_silence_when_confident : NotGivenOr [int ] = NOT_GIVEN ,
6572 max_turn_silence : NotGivenOr [int ] = NOT_GIVEN ,
@@ -83,6 +90,7 @@ def __init__(
8390 sample_rate = sample_rate ,
8491 buffer_size_seconds = buffer_size_seconds ,
8592 encoding = encoding ,
93+ speech_model = model ,
8694 end_of_turn_confidence_threshold = end_of_turn_confidence_threshold ,
8795 min_end_of_turn_silence_when_confident = min_end_of_turn_silence_when_confident ,
8896 max_turn_silence = max_turn_silence ,
@@ -301,6 +309,7 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
301309 live_config = {
302310 "sample_rate" : self ._opts .sample_rate ,
303311 "encoding" : self ._opts .encoding ,
312+ "speech_model" : self ._opts .speech_model ,
304313 "format_turns" : self ._opts .format_turns if is_given (self ._opts .format_turns ) else None ,
305314 "end_of_turn_confidence_threshold" : self ._opts .end_of_turn_confidence_threshold
306315 if is_given (self ._opts .end_of_turn_confidence_threshold )
@@ -335,23 +344,37 @@ async def _connect_ws(self) -> aiohttp.ClientWebSocketResponse:
335344 def _process_stream_event (self , data : dict ) -> None :
336345 message_type = data .get ("type" )
337346 if message_type == "Turn" :
338- transcript = data .get ("transcript" )
339347 words = data .get ("words" , [])
340- end_of_turn = data .get ("end_of_turn" )
341-
342- if transcript and end_of_turn :
343- turn_is_formatted = data .get ("turn_is_formatted" , False )
344- if not self ._opts .format_turns or (self ._opts .format_turns and turn_is_formatted ):
345- final_event = stt .SpeechEvent (
346- type = stt .SpeechEventType .FINAL_TRANSCRIPT ,
347- # TODO: We can't know the language?
348- alternatives = [stt .SpeechData (language = "en-US" , text = transcript )],
349- )
350- else :
351- # skip emitting final transcript if format_turns is enabled but this
352- # turn isn't formatted
353- return
348+ end_of_turn = data .get ("end_of_turn" , False )
349+ turn_is_formatted = data .get ("turn_is_formatted" , False )
350+ utterance = data .get ("utterance" , "" )
351+ transcript = data .get ("transcript" , "" )
352+
353+ if words :
354+ interim_text = " " .join (word .get ("text" , "" ) for word in words )
355+ interim_event = stt .SpeechEvent (
356+ type = stt .SpeechEventType .INTERIM_TRANSCRIPT ,
357+ alternatives = [stt .SpeechData (language = "en" , text = interim_text )],
358+ )
359+ self ._event_ch .send_nowait (interim_event )
360+
361+ if utterance :
362+ final_event = stt .SpeechEvent (
363+ type = stt .SpeechEventType .PREFLIGHT_TRANSCRIPT ,
364+ alternatives = [stt .SpeechData (language = "en" , text = utterance )],
365+ )
354366 self ._event_ch .send_nowait (final_event )
367+
368+ if end_of_turn and (
369+ not (is_given (self ._opts .format_turns ) and self ._opts .format_turns )
370+ or turn_is_formatted
371+ ):
372+ final_event = stt .SpeechEvent (
373+ type = stt .SpeechEventType .FINAL_TRANSCRIPT ,
374+ alternatives = [stt .SpeechData (language = "en" , text = transcript )],
375+ )
376+ self ._event_ch .send_nowait (final_event )
377+
355378 self ._event_ch .send_nowait (stt .SpeechEvent (type = stt .SpeechEventType .END_OF_SPEECH ))
356379
357380 if self ._speech_duration > 0.0 :
@@ -364,12 +387,3 @@ def _process_stream_event(self, data: dict) -> None:
364387 )
365388 self ._event_ch .send_nowait (usage_event )
366389 self ._speech_duration = 0
367-
368- else :
369- non_final_words = [word ["text" ] for word in words if not word ["word_is_final" ]]
370- interim = " " .join (non_final_words )
371- interim_event = stt .SpeechEvent (
372- type = stt .SpeechEventType .INTERIM_TRANSCRIPT ,
373- alternatives = [stt .SpeechData (language = "en-US" , text = f"{ transcript } { interim } " )],
374- )
375- self ._event_ch .send_nowait (interim_event )
0 commit comments