Skip to content

Commit 918b1af

Browse files
authored
bidi audio io - handle interruption (strands-agents#45)
1 parent 8dcee5d commit 918b1af

File tree

7 files changed

+235
-122
lines changed

7 files changed

+235
-122
lines changed

src/strands/experimental/bidi/agent/agent.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -428,15 +428,15 @@ async def run_outputs():
428428
for output in outputs:
429429
if hasattr(output, "start"):
430430
await output.start()
431-
431+
432432
# Start agent after all IO is ready
433433
await self.start()
434434
try:
435435
await asyncio.gather(run_inputs(), run_outputs(), return_exceptions=True)
436436

437437
finally:
438438
await self.stop()
439-
439+
440440
for input_ in inputs:
441441
if hasattr(input_, "stop"):
442442
await input_.stop()

src/strands/experimental/bidi/agent/loop.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import logging
88
from typing import AsyncIterable, Awaitable, TYPE_CHECKING
99

10-
from ..types.events import BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent, BidiTranscriptStreamEvent
10+
from ..types.events import BidiOutputEvent, BidiTranscriptStreamEvent
1111
from ....types._events import ToolResultEvent, ToolResultMessageEvent, ToolStreamEvent, ToolUseStreamEvent
1212
from ....types.content import Message
1313
from ....types.tools import ToolResult, ToolUse
@@ -117,13 +117,6 @@ async def _run_model(self) -> None:
117117
elif isinstance(event, ToolUseStreamEvent):
118118
self._create_task(self._run_tool(event["current_tool_use"]))
119119

120-
elif isinstance(event, BidiInterruptionEvent):
121-
# clear the audio
122-
for _ in range(self._event_queue.qsize()):
123-
event = self._event_queue.get_nowait()
124-
if not isinstance(event, BidiAudioStreamEvent):
125-
self._event_queue.put_nowait(event)
126-
127120
async def _run_tool(self, tool_use: ToolUse) -> None:
128121
"""Task for running tool requested by the model."""
129122
logger.debug("running tool")
Lines changed: 150 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -1,161 +1,201 @@
1-
"""AudioIO - Clean separation of audio functionality from core BidiAgent.
1+
"""Send and receive audio data from devices.
22
3-
Provides audio input/output capabilities for BidiAgent through the BidiIO protocol.
4-
Handles all PyAudio setup, streaming, and cleanup while keeping the core agent data-agnostic.
3+
Reads user audio from input device and sends agent audio to output device using PyAudio. If a user interrupts the agent,
4+
the output buffer is cleared to stop playback.
55
"""
66

77
import asyncio
88
import base64
99
import logging
10+
from collections import deque
11+
from typing import Any
1012

1113
import pyaudio
1214

1315
from ..types.io import BidiInput, BidiOutput
14-
from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiOutputEvent
16+
from ..types.events import BidiAudioInputEvent, BidiAudioStreamEvent, BidiInterruptionEvent, BidiOutputEvent
1517

1618
logger = logging.getLogger(__name__)
1719

1820

1921
class _BidiAudioInput(BidiInput):
20-
"""Handle audio input from bidi agent."""
21-
def __init__(self, audio: "BidiAudioIO") -> None:
22-
"""Store reference to pyaudio instance."""
23-
self.audio = audio
24-
22+
"""Handle audio input from user.
23+
24+
Attributes:
25+
_audio: PyAudio instance for audio system access.
26+
_stream: Audio input stream.
27+
"""
28+
29+
_audio: pyaudio.PyAudio
30+
_stream: pyaudio.Stream
31+
32+
_CHANNELS: int = 1
33+
_DEVICE_INDEX: int | None = None
34+
_ENCODING: str = "pcm"
35+
_FORMAT: int = pyaudio.paInt16
36+
_FRAMES_PER_BUFFER: int = 512
37+
_RATE: int = 16000
38+
39+
def __init__(self, config: dict[str, Any]) -> None:
40+
"""Extract configs."""
41+
self._channels = config.get("input_channels", _BidiAudioInput._CHANNELS)
42+
self._device_index = config.get("input_device_index", _BidiAudioInput._DEVICE_INDEX)
43+
self._format = config.get("input_format", _BidiAudioInput._FORMAT)
44+
self._frames_per_buffer = config.get("input_frames_per_buffer", _BidiAudioInput._FRAMES_PER_BUFFER)
45+
self._rate = config.get("input_rate", _BidiAudioInput._RATE)
46+
2547
async def start(self) -> None:
26-
"""Start audio input."""
27-
self.audio._start()
48+
"""Start input stream."""
49+
self._audio = pyaudio.PyAudio()
50+
self._stream = self._audio.open(
51+
channels=self._channels,
52+
format=self._format,
53+
frames_per_buffer=self._frames_per_buffer,
54+
input=True,
55+
input_device_index=self._device_index,
56+
rate=self._rate,
57+
)
2858

2959
async def stop(self) -> None:
30-
"""Stop audio input."""
31-
self.audio._stop()
60+
"""Stop input stream."""
61+
# TODO: Provide time for streaming thread to exit cleanly to prevent conflicts with the Nova threads.
62+
# See if we can remove after properly handling cancellation for agent.
63+
await asyncio.sleep(0.1)
64+
65+
self._stream.close()
66+
self._audio.terminate()
67+
68+
self._stream = None
69+
self._audio = None
3270

3371
async def __call__(self) -> BidiAudioInputEvent:
34-
"""Read audio from microphone."""
35-
audio_bytes = self.audio.input_stream.read(self.audio.chunk_size, exception_on_overflow=False)
72+
"""Read audio from input stream."""
73+
audio_bytes = await asyncio.to_thread(
74+
self._stream.read, self._frames_per_buffer, exception_on_overflow=False
75+
)
3676

3777
return BidiAudioInputEvent(
3878
audio=base64.b64encode(audio_bytes).decode("utf-8"),
39-
format="pcm",
40-
sample_rate=self.audio.input_sample_rate,
41-
channels=self.audio.input_channels,
79+
channels=self._channels,
80+
format=_BidiAudioInput._ENCODING,
81+
sample_rate=self._rate,
4282
)
4383

4484

4585
class _BidiAudioOutput(BidiOutput):
46-
"""Handle audio output from bidi agent."""
47-
def __init__(self, audio: "BidiAudioIO") -> None:
48-
"""Store reference to pyaudio instance."""
49-
self.audio = audio
86+
"""Handle audio output from bidi agent.
87+
88+
Attributes:
89+
_audio: PyAudio instance for audio system access.
90+
_stream: Audio output stream.
91+
_buffer: Deque buffer for queuing audio data.
92+
_buffer_event: Event to signal when buffer has data.
93+
_output_task: Background task for processing audio output.
94+
"""
95+
96+
_audio: pyaudio.PyAudio
97+
_stream: pyaudio.Stream
98+
_buffer: deque
99+
_buffer_event: asyncio.Event
100+
_output_task: asyncio.Task
101+
102+
_BUFFER_SIZE: int | None = None
103+
_CHANNELS: int = 1
104+
_DEVICE_INDEX: int | None = None
105+
_FORMAT: int = pyaudio.paInt16
106+
_FRAMES_PER_BUFFER: int = 512
107+
_RATE: int = 16000
108+
109+
def __init__(self, config: dict[str, Any]) -> None:
110+
"""Extract configs."""
111+
self._buffer_size = config.get("output_buffer_size", _BidiAudioOutput._BUFFER_SIZE)
112+
self._channels = config.get("output_channels", _BidiAudioOutput._CHANNELS)
113+
self._device_index = config.get("output_device_index", _BidiAudioOutput._DEVICE_INDEX)
114+
self._format = config.get("output_format", _BidiAudioOutput._FORMAT)
115+
self._frames_per_buffer = config.get("output_frames_per_buffer", _BidiAudioOutput._FRAMES_PER_BUFFER)
116+
self._rate = config.get("output_rate", _BidiAudioOutput._RATE)
50117

51118
async def start(self) -> None:
52-
"""Start audio output."""
53-
self.audio._start()
119+
"""Start output stream."""
120+
self._audio = pyaudio.PyAudio()
121+
self._stream = self._audio.open(
122+
channels=self._channels,
123+
format=self._format,
124+
frames_per_buffer=self._frames_per_buffer,
125+
output=True,
126+
output_device_index=self._device_index,
127+
rate=self._rate,
128+
)
129+
self._buffer = deque(maxlen=self._buffer_size)
130+
self._buffer_event = asyncio.Event()
131+
self._output_task = asyncio.create_task(self._output())
54132

55133
async def stop(self) -> None:
56-
"""Stop audio output."""
57-
self.audio._stop()
134+
"""Stop output stream."""
135+
self._buffer.clear()
136+
self._buffer.append(None)
137+
self._buffer_event.set()
138+
await self._output_task
139+
140+
self._stream.close()
141+
self._audio.terminate()
142+
143+
self._output_task = None
144+
self._buffer = None
145+
self._buffer_event = None
146+
self._stream = None
147+
self._audio = None
58148

59149
async def __call__(self, event: BidiOutputEvent) -> None:
60150
"""Handle audio events with direct stream writing."""
61151
if isinstance(event, BidiAudioStreamEvent):
62-
self.audio.output_stream.write(base64.b64decode(event["audio"]))
152+
audio_bytes = base64.b64decode(event["audio"])
153+
self._buffer.append(audio_bytes)
154+
self._buffer_event.set()
155+
156+
elif isinstance(event, BidiInterruptionEvent):
157+
self._buffer.clear()
158+
self._buffer_event.clear()
63159

64-
# TODO: Outputing audio to speakers is a sync operation. Adding sleep to prevent event loop hogging. Will
65-
# follow up on identifying a cleaner approach.
66-
await asyncio.sleep(0.01)
160+
async def _output(self) -> None:
161+
while True:
162+
await self._buffer_event.wait()
163+
self._buffer_event.clear()
164+
165+
while self._buffer:
166+
audio_bytes = self._buffer.popleft()
167+
if not audio_bytes:
168+
return
169+
170+
await asyncio.to_thread(self._stream.write, audio_bytes)
67171

68172

69173
class BidiAudioIO:
70-
"""Audio IO channel for BidiAgent with direct stream processing."""
174+
"""Send and receive audio data from devices."""
71175

72-
def __init__(
73-
self,
74-
audio_config: dict | None = None,
75-
):
76-
"""Initialize AudioIO with clean audio configuration.
176+
def __init__(self, **config: Any) -> None:
177+
"""Initialize audio devices.
77178
78179
Args:
79-
audio_config: Dictionary containing audio configuration:
80-
- input_sample_rate (int): Microphone sample rate (default: 24000)
81-
- output_sample_rate (int): Speaker sample rate (default: 24000)
82-
- chunk_size (int): Audio chunk size in bytes (default: 1024)
83-
- input_device_index (int): Specific input device (optional)
84-
- output_device_index (int): Specific output device (optional)
180+
**config: Dictionary containing audio configuration:
85181
- input_channels (int): Input channels (default: 1)
182+
- input_device_index (int): Specific input device (optional)
183+
- input_format (int): Audio format (default: paInt16)
184+
- input_frames_per_buffer (int): Frames per buffer (default: 512)
185+
- input_rate (int): Input sample rate (default: 16000)
186+
- output_buffer_size (int): Maximum output buffer size (default: None)
86187
- output_channels (int): Output channels (default: 1)
188+
- output_device_index (int): Specific output device (optional)
189+
- output_format (int): Audio format (default: paInt16)
190+
- output_frames_per_buffer (int): Frames per buffer (default: 512)
191+
- output_rate (int): Output sample rate (default: 16000)
87192
"""
88-
default_config = {
89-
"input_sample_rate": 16000,
90-
"output_sample_rate": 16000,
91-
"chunk_size": 512,
92-
"input_device_index": None,
93-
"output_device_index": None,
94-
"input_channels": 1,
95-
"output_channels": 1,
96-
}
97-
98-
# Merge user config with defaults
99-
if audio_config:
100-
default_config.update(audio_config)
101-
102-
# Set audio configuration attributes
103-
self.input_sample_rate = default_config["input_sample_rate"]
104-
self.output_sample_rate = default_config["output_sample_rate"]
105-
self.chunk_size = default_config["chunk_size"]
106-
self.input_device_index = default_config["input_device_index"]
107-
self.output_device_index = default_config["output_device_index"]
108-
self.input_channels = default_config["input_channels"]
109-
self.output_channels = default_config["output_channels"]
110-
111-
# Audio infrastructure
112-
self.audio = None
113-
self.input_stream = None
114-
self.output_stream = None
115-
self.interrupted = False
193+
self._config = config
116194

117195
def input(self) -> _BidiAudioInput:
118196
"""Return audio processing BidiInput"""
119-
return _BidiAudioInput(self)
197+
return _BidiAudioInput(self._config)
120198

121199
def output(self) -> _BidiAudioOutput:
122200
"""Return audio processing BidiOutput"""
123-
return _BidiAudioOutput(self)
124-
125-
def _start(self) -> None:
126-
"""Setup PyAudio streams for input and output."""
127-
if self.audio:
128-
return
129-
130-
self.audio = pyaudio.PyAudio()
131-
132-
self.input_stream = self.audio.open(
133-
format=pyaudio.paInt16,
134-
channels=self.input_channels,
135-
rate=self.input_sample_rate,
136-
input=True,
137-
frames_per_buffer=self.chunk_size,
138-
input_device_index=self.input_device_index,
139-
)
140-
141-
self.output_stream = self.audio.open(
142-
format=pyaudio.paInt16,
143-
channels=self.output_channels,
144-
rate=self.output_sample_rate,
145-
output=True,
146-
frames_per_buffer=self.chunk_size,
147-
output_device_index=self.output_device_index,
148-
)
149-
150-
def _stop(self) -> None:
151-
"""Clean up IO channel resources."""
152-
if not self.audio:
153-
return
154-
155-
self.input_stream.close()
156-
self.output_stream.close()
157-
self.audio.terminate()
158-
159-
self.input_stream = None
160-
self.output_stream = None
161-
self.audio = None
201+
return _BidiAudioOutput(self._config)

src/strands/experimental/bidi/models/novasonic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -519,7 +519,7 @@ def _convert_nova_event(self, nova_event: dict[str, any]) -> BidiOutputEvent | N
519519
return BidiAudioStreamEvent(
520520
audio=audio_content,
521521
format="pcm",
522-
sample_rate=24000,
522+
sample_rate=NOVA_AUDIO_OUTPUT_CONFIG["sampleRateHertz"],
523523
channels=1
524524
)
525525

src/strands/experimental/bidi/scripts/test_bidi.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ async def main():
1717

1818

1919
# Nova Sonic model
20-
audio_io = BidiAudioIO(audio_config={})
20+
audio_io = BidiAudioIO()
2121
text_io = BidiTextIO()
2222
model = BidiNovaSonicModel(region="us-east-1")
2323

tests/strands/experimental/bidi/io/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)