Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def __init__(self, *, tts: TTS, conn_options: APIConnectOptions):
super().__init__(tts=tts, conn_options=conn_options)
self._tts: TTS = tts
self._opts = replace(tts._opts)
self._segments_ch = utils.aio.Chan[tokenize.WordStream]()
self._segments_ch = utils.aio.Chan[tokenize.SentenceStream]()

async def _run(self, output_emitter: tts.AudioEmitter) -> None:
request_id = utils.shortuuid()
Expand All @@ -314,23 +314,23 @@ async def _run(self, output_emitter: tts.AudioEmitter) -> None:
)

async def _tokenize_input() -> None:
word_stream = None
chunks_stream = None
async for input in self._input_ch:
if isinstance(input, str):
if word_stream is None:
word_stream = self._opts.word_tokenizer.stream()
self._segments_ch.send_nowait(word_stream)
word_stream.push_text(input)
if chunks_stream is None:
chunks_stream = self._tts._sentence_tokenizer.stream()
self._segments_ch.send_nowait(chunks_stream)
chunks_stream.push_text(input)
elif isinstance(input, self._FlushSentinel):
if word_stream:
word_stream.end_input()
word_stream = None
if chunks_stream:
chunks_stream.end_input()
chunks_stream = None

self._segments_ch.close()

async def _run_segments() -> None:
async for word_stream in self._segments_ch:
await self._run_ws(word_stream, output_emitter)
async for chunk_stream in self._segments_ch:
await self._run_ws(chunk_stream, output_emitter)

tasks = [
asyncio.create_task(_tokenize_input()),
Expand All @@ -353,19 +353,21 @@ async def _run_segments() -> None:
await utils.aio.gracefully_cancel(*tasks)

async def _run_ws(
self, word_stream: tokenize.WordStream, output_emitter: tts.AudioEmitter
self, chunks_stream: tokenize.SentenceStream, output_emitter: tts.AudioEmitter
) -> None:
segment_id = utils.shortuuid()
output_emitter.start_segment(segment_id=segment_id)
chunks = 0

async def send_task(ws: aiohttp.ClientWebSocketResponse) -> None:
async for word in word_stream:
text_msg = {"text": f"{word.token} "}
async for sentence in chunks_stream:
self._mark_started()
await ws.send_str(json.dumps(text_msg))

stop_msg = {"text": "<STOP>"}
await ws.send_str(json.dumps(stop_msg))
nonlocal chunks
chunks += 1

msg = {"text": f"{sentence.token}<STOP>", "context_id": segment_id}
await ws.send_str(json.dumps(msg))

async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:
while True:
Expand All @@ -390,18 +392,21 @@ async def recv_task(ws: aiohttp.ClientWebSocketResponse) -> None:

data = resp.get("data", {})
audio_data = data.get("audio")
if audio_data and audio_data != "":
if audio_data and audio_data != "" and data.get("context_id") == segment_id:
try:
b64data = base64.b64decode(audio_data)
if b64data:
output_emitter.push(b64data)
except Exception as e:
logger.warning("Failed to decode NeuPhonic audio data: %s", e)

nonlocal chunks
if data.get("stop"):
chunks -= 1

if data.get("context_id") != segment_id or chunks == 0:
output_emitter.end_segment()
break

elif msg.type == aiohttp.WSMsgType.BINARY:
pass
else:
Expand Down