Skip to content

Commit 779d9cd

Browse files
committed
test(client): add comprehensive unit tests for 100% coverage
1 parent c87a69a commit 779d9cd

File tree

2 files changed

+126
-128
lines changed

2 files changed

+126
-128
lines changed

packages/toolbox-adk/src/toolbox_adk/client.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222

2323
from .credentials import CredentialConfig, CredentialType
2424

25-
# Global ContextVar for User Identity (3LO) tokens to be injected per-request
2625
USER_TOKEN_CONTEXT_VAR: ContextVar[Optional[str]] = ContextVar(
2726
"toolbox_user_token", default=None
2827
)
@@ -53,11 +52,6 @@ def __init__(
5352
self._credentials = credentials
5453
self._additional_headers = additional_headers or {}
5554

56-
# Prepare auth_token_getters for toolbox-core
57-
# toolbox_core expects: dict[str, Callable[[], str | Awaitable[str]]]
58-
# However, for general headers (like Authorization), we can pass them in client_headers
59-
# if they are static or simpler. Toolbox-core supports `client_headers` which can be dynamic.
60-
6155
self._core_client_headers: Dict[
6256
str, Union[str, Callable[[], str], Callable[[], Awaitable[str]]]
6357
] = {}
@@ -85,8 +79,6 @@ def _configure_auth(self, creds: CredentialConfig) -> None:
8579
)
8680

8781
# Create an async capable token getter
88-
# We wrap it to match the signature expected by toolbox-core headers
89-
# (which accepts callables)
9082
self._core_client_headers["Authorization"] = self._create_adc_token_getter(
9183
creds.target_audience
9284
)
@@ -108,14 +100,10 @@ def _configure_auth(self, creds: CredentialConfig) -> None:
108100

109101
elif creds.type == CredentialType.USER_IDENTITY:
110102
# For USER_IDENTITY (3LO), the *Tool* handles the interactive flow at runtime.
111-
# We use a ContextVar to inject the token per-request.
112103

113104
def get_user_token() -> str:
114105
token = USER_TOKEN_CONTEXT_VAR.get()
115106
if not token:
116-
# If this is called but no token is set in context, it means
117-
# the tool wrapper failed to set it or we are in a context where
118-
# we expected it. We return empty string which might cause 401.
119107
return ""
120108
return f"Bearer {token}"
121109

@@ -125,28 +113,19 @@ def _create_adc_token_getter(self, audience: str) -> Callable[[], str]:
125113
"""Returns a callable that fetches a fresh ID token using ADC."""
126114

127115
def get_token() -> str:
128-
# Note: This is a synchronous call. Toolbox-core supports sync callables in headers.
129-
# Ideally we would use async but google-auth is primarily sync for these helpers.
130116
request = transport.requests.Request()
131-
# Try to get ID token directly (e.g. on Cloud Run)
132117
try:
133118
token = id_token.fetch_id_token(request, audience)
134119
return f"Bearer {token}"
135120
except Exception:
136-
# Fallback to default credentials (e.g. local gcloud)
121+
# Fallback to default credentials
137122
creds, _ = google.auth.default()
138123
if not creds.valid:
139124
creds.refresh(request)
140-
# If specific ID token credentials, use them, otherwise this might be Access Token (scoped)
141-
# For Toolbox we usually need ID Tokens.
142-
# If the user is locally auth'd via `gcloud auth login`, fetch_id_token is preferred.
143-
# If falling back to service account file:
125+
144126
if hasattr(creds, "id_token") and creds.id_token:
145127
return f"Bearer {creds.id_token}"
146128

147-
# If we are here, we might need to manually sign via IAM or similar if it's a generic SA.
148-
# For simplicity in this v1, we assume fetch_id_token works or standard creds work.
149-
# Re-attempt fetch_id_token on the credentials object if possible
150129
curr_token = getattr(creds, "token", None)
151130
return f"Bearer {curr_token}" if curr_token else ""
152131

packages/toolbox-adk/tests/unit/test_client.py

Lines changed: 124 additions & 105 deletions
Original file line numberDiff line numberDiff line change
@@ -12,159 +12,178 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
from unittest.mock import ANY, AsyncMock, MagicMock, patch
15+
import unittest
16+
from unittest.mock import AsyncMock, MagicMock, patch
1617

1718
import pytest
1819

19-
from toolbox_adk.client import ToolboxClient
20-
from toolbox_adk.credentials import CredentialStrategy
20+
from toolbox_adk import CredentialStrategy, ToolboxClient
21+
from toolbox_adk.client import CredentialType
2122

2223

23-
class TestToolboxClient:
24+
@pytest.mark.asyncio
25+
class TestToolboxClientAuth:
26+
"""Unit tests for Client Auth logic."""
2427

2528
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
26-
def test_init_no_auth(self, mock_core_client):
29+
async def test_init_toolbox_identity(self, mock_core_client):
30+
"""Test init with TOOLBOX_IDENTITY (no auth headers)."""
2731
creds = CredentialStrategy.TOOLBOX_IDENTITY()
28-
client = ToolboxClient("http://server", credentials=creds)
32+
client = ToolboxClient(server_url="http://test", credentials=creds)
2933

30-
mock_core_client.assert_called_with(
31-
server_url="http://server", client_headers={}
32-
)
34+
# Verify core client created with empty headers for auth
35+
_, kwargs = mock_core_client.call_args
36+
assert "client_headers" in kwargs
37+
headers = kwargs["client_headers"]
38+
assert "Authorization" not in headers
3339

3440
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
35-
def test_init_manual_token(self, mock_core_client):
36-
creds = CredentialStrategy.MANUAL_TOKEN("abc")
37-
client = ToolboxClient("http://server", credentials=creds)
41+
@patch("toolbox_adk.client.id_token.fetch_id_token")
42+
@patch("toolbox_adk.client.google.auth.default")
43+
@patch("toolbox_adk.client.transport.requests.Request")
44+
async def test_init_adc_success_fetch_id_token(
45+
self, mock_req, mock_default, mock_fetch_id, mock_core_client
46+
):
47+
"""Test ADC strategy where fetch_id_token succeeds."""
48+
mock_fetch_id.return_value = "id-token-123"
3849

39-
mock_core_client.assert_called_with(
40-
server_url="http://server", client_headers={"Authorization": "Bearer abc"}
50+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(
51+
target_audience="aud"
4152
)
53+
client = ToolboxClient(server_url="http://test", credentials=creds)
4254

43-
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
44-
def test_init_additional_headers(self, mock_core_client):
45-
creds = CredentialStrategy.TOOLBOX_IDENTITY()
46-
headers = {"X-Custom": "Val"}
47-
client = ToolboxClient(
48-
"http://server", credentials=creds, additional_headers=headers
49-
)
55+
_, kwargs = mock_core_client.call_args
56+
headers = kwargs["client_headers"]
57+
assert "Authorization" in headers
58+
token_getter = headers["Authorization"]
59+
assert callable(token_getter)
5060

51-
mock_core_client.assert_called_with(
52-
server_url="http://server", client_headers={"X-Custom": "Val"}
53-
)
61+
# Call the getter
62+
token_val = token_getter()
63+
assert token_val == "Bearer id-token-123"
64+
mock_fetch_id.assert_called()
5465

5566
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
5667
@patch("toolbox_adk.client.id_token.fetch_id_token")
57-
def test_adc_auth_flow_success(self, mock_fetch_token, mock_core_client):
58-
mock_fetch_token.return_value = "id_token_123"
68+
@patch("toolbox_adk.client.google.auth.default")
69+
@patch("toolbox_adk.client.transport.requests.Request")
70+
async def test_init_adc_fallback_creds(
71+
self, mock_req, mock_default, mock_fetch_id, mock_core_client
72+
):
73+
"""Test ADC strategy fallback to default() when fetch_id_token fails."""
74+
mock_fetch_id.side_effect = Exception("No metadata server")
5975

60-
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
61-
client = ToolboxClient("http://server", credentials=creds)
76+
# Mock default creds
77+
mock_creds = MagicMock()
78+
mock_creds.valid = False
79+
mock_creds.id_token = "fallback-id-token"
80+
mock_default.return_value = (mock_creds, "proj")
6281

63-
# Verify a callable was passed
64-
args, kwargs = mock_core_client.call_args
65-
assert "Authorization" in kwargs["client_headers"]
66-
token_getter = kwargs["client_headers"]["Authorization"]
67-
assert callable(token_getter)
82+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(
83+
target_audience="aud"
84+
)
85+
client = ToolboxClient(server_url="http://test", credentials=creds)
6886

69-
# Verify callable behavior
87+
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
7088
token = token_getter()
71-
assert token == "Bearer id_token_123"
72-
mock_fetch_token.assert_called()
89+
assert token == "Bearer fallback-id-token"
90+
mock_creds.refresh.assert_called() # Because we set valid=False
7391

7492
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
7593
@patch("toolbox_adk.client.id_token.fetch_id_token")
7694
@patch("toolbox_adk.client.google.auth.default")
77-
def test_adc_auth_flow_fallback(
78-
self, mock_default, mock_fetch_token, mock_core_client
95+
@patch("toolbox_adk.client.transport.requests.Request")
96+
async def test_init_adc_fallback_creds_token(
97+
self, mock_req, mock_default, mock_fetch_id, mock_core_client
7998
):
80-
# unexpected error on fetch_id_token
81-
mock_fetch_token.side_effect = Exception("No metadata")
99+
"""Test ADC fallback when creds have .token but no .id_token."""
100+
mock_fetch_id.side_effect = Exception("No metadata server")
82101

83102
mock_creds = MagicMock()
84-
mock_creds.valid = False
85-
mock_creds.id_token = "fallback_id_token"
103+
mock_creds.valid = True
104+
del mock_creds.id_token # Simulate no id_token attr or None
105+
mock_creds.token = "access-token-123" # e.g. user creds
86106
mock_default.return_value = (mock_creds, "proj")
87107

88-
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
89-
client = ToolboxClient("http://server", credentials=creds)
90-
108+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(
109+
target_audience="aud"
110+
)
111+
client = ToolboxClient(server_url="http://test", credentials=creds)
91112
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
92-
token = token_getter()
113+
assert token_getter() == "Bearer access-token-123"
93114

94-
assert token == "Bearer fallback_id_token"
95-
mock_creds.refresh.assert_called()
115+
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
116+
async def test_init_manual_token(self, mock_core_client):
117+
creds = CredentialStrategy.MANUAL_TOKEN(token="abc")
118+
client = ToolboxClient("http://test", credentials=creds)
119+
headers = mock_core_client.call_args[1]["client_headers"]
120+
assert headers["Authorization"] == "Bearer abc"
96121

97122
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
98-
def test_manual_creds(self, mock_core_client):
99-
mock_g_creds = MagicMock()
100-
mock_g_creds.valid = False
101-
mock_g_creds.token = "oauth_token"
123+
async def test_init_manual_creds(self, mock_core_client):
124+
mock_google_creds = MagicMock()
125+
mock_google_creds.valid = True
126+
mock_google_creds.token = "creds-token"
102127

103-
creds = CredentialStrategy.MANUAL_CREDS(mock_g_creds)
104-
client = ToolboxClient("http://server", credentials=creds)
128+
creds = CredentialStrategy.MANUAL_CREDS(credentials=mock_google_creds)
129+
client = ToolboxClient("http://test", credentials=creds)
105130

106131
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
107-
token = token_getter()
108-
109-
assert token == "Bearer oauth_token"
110-
assert token == "Bearer oauth_token"
111-
mock_g_creds.refresh.assert_called()
132+
assert token_getter() == "Bearer creds-token"
112133

113134
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
114-
def test_init_validation_errors(self, mock_core_client):
115-
# ADC missing audience
116-
with pytest.raises(ValueError, match="target_audience is required"):
117-
# Fix: only pass target_audience as keyword arg OR positional, not both mixed in a way that causes overlap if defined so
118-
# Actually simpler: just pass raw None
119-
ToolboxClient(
120-
"url",
121-
credentials=CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(None),
122-
)
123-
124-
# Manual token missing token
125-
with pytest.raises(ValueError, match="token is required"):
126-
ToolboxClient("url", credentials=CredentialStrategy.MANUAL_TOKEN(None))
127-
128-
# Manual creds missing credentials
129-
with pytest.raises(ValueError, match="credentials object is required"):
130-
ToolboxClient("url", credentials=CredentialStrategy.MANUAL_CREDS(None))
135+
async def test_init_user_identity(self, mock_core_client):
136+
creds = CredentialStrategy.USER_IDENTITY(client_id="c", client_secret="s")
137+
client = ToolboxClient("http://test", credentials=creds)
131138

132-
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
133-
@patch("toolbox_adk.client.id_token.fetch_id_token")
134-
@patch("toolbox_adk.client.google.auth.default")
135-
def test_adc_auth_flow_fallback_access_token(
136-
self, mock_default, mock_fetch_token, mock_core_client
137-
):
138-
# fetch_id_token fails
139-
mock_fetch_token.side_effect = Exception("No metadata")
139+
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
140+
# Should be empty initially
141+
assert token_getter() == ""
140142

141-
mock_creds = MagicMock()
142-
mock_creds.valid = False
143-
mock_creds.id_token = None # No ID token
144-
mock_creds.token = "access_token_123"
145-
mock_default.return_value = (mock_creds, "proj")
143+
# Set context
144+
from toolbox_adk.client import USER_TOKEN_CONTEXT_VAR
146145

147-
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS("http://aud")
148-
client = ToolboxClient("http://server", credentials=creds)
146+
token = USER_TOKEN_CONTEXT_VAR.set("user-tok")
147+
try:
148+
assert token_getter() == "Bearer user-tok"
149+
finally:
150+
USER_TOKEN_CONTEXT_VAR.reset(token)
149151

150-
token_getter = mock_core_client.call_args[1]["client_headers"]["Authorization"]
151-
token = token_getter()
152+
async def test_validation_errors(self):
153+
with pytest.raises(ValueError):
154+
# ADC requires audience
155+
creds = CredentialStrategy.APPLICATION_DEFAULT_CREDENTIALS(target_audience="")
156+
ToolboxClient("http://test", credentials=creds)
152157

153-
assert token == "Bearer access_token_123"
154-
mock_creds.refresh.assert_called()
158+
with pytest.raises(ValueError):
159+
creds = CredentialStrategy.MANUAL_TOKEN(token="")
160+
ToolboxClient("http://test", credentials=creds)
161+
162+
with pytest.raises(ValueError):
163+
creds = CredentialStrategy.MANUAL_CREDS(credentials=None)
164+
ToolboxClient("http://test", credentials=creds)
155165

156-
@pytest.mark.asyncio
157166
@patch("toolbox_adk.client.toolbox_core.ToolboxClient")
158-
async def test_delegation(self, mock_core_client):
159-
mock_instance = mock_core_client.return_value
160-
mock_instance.load_toolset = AsyncMock(return_value=["t1"])
161-
mock_instance.close = AsyncMock()
167+
async def test_load_methods(self, mock_core_client_class):
168+
# Setup mock instance
169+
mock_instance = AsyncMock()
170+
mock_core_client_class.return_value = mock_instance
162171

163-
client = ToolboxClient("http://server")
164-
tools = await client.load_toolset("my-set", extra="arg")
172+
client = ToolboxClient(
173+
"http://test", credentials=CredentialStrategy.TOOLBOX_IDENTITY()
174+
)
165175

166-
mock_instance.load_toolset.assert_awaited_with("my-set", extra="arg")
167-
assert tools == ["t1"]
176+
# Test load_toolset
177+
await client.load_toolset("ts", foo="bar")
178+
mock_instance.load_toolset.assert_called_with("ts", foo="bar")
168179

180+
# Test load_tool
181+
await client.load_tool("t", baz="qux")
182+
mock_instance.load_tool.assert_called_with("t", baz="qux")
183+
184+
# Test close
169185
await client.close()
170-
mock_instance.close.assert_awaited()
186+
mock_instance.close.assert_called_once()
187+
188+
# Test property
189+
assert client.credential_config is not None

0 commit comments

Comments
 (0)