Skip to content

Commit c0bcc45

Browse files
authored
Merge pull request #15 from stainless-sdks/bbatha/agent-endpoint-in-constructor
move agent endpoint to the constructor
2 parents c14b456 + 5610f5f commit c0bcc45

File tree

4 files changed

+222
-111
lines changed

4 files changed

+222
-111
lines changed

src/do_gradientai/_client.py

Lines changed: 132 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,18 @@
6161
from .resources.droplets.droplets import DropletsResource, AsyncDropletsResource
6262
from .resources.firewalls.firewalls import FirewallsResource, AsyncFirewallsResource
6363
from .resources.inference.inference import InferenceResource, AsyncInferenceResource
64-
from .resources.floating_ips.floating_ips import FloatingIPsResource, AsyncFloatingIPsResource
65-
from .resources.load_balancers.load_balancers import LoadBalancersResource, AsyncLoadBalancersResource
66-
from .resources.knowledge_bases.knowledge_bases import KnowledgeBasesResource, AsyncKnowledgeBasesResource
64+
from .resources.floating_ips.floating_ips import (
65+
FloatingIPsResource,
66+
AsyncFloatingIPsResource,
67+
)
68+
from .resources.load_balancers.load_balancers import (
69+
LoadBalancersResource,
70+
AsyncLoadBalancersResource,
71+
)
72+
from .resources.knowledge_bases.knowledge_bases import (
73+
KnowledgeBasesResource,
74+
AsyncKnowledgeBasesResource,
75+
)
6776

6877
__all__ = [
6978
"Timeout",
@@ -82,15 +91,15 @@ class GradientAI(SyncAPIClient):
8291
api_key: str | None
8392
inference_key: str | None
8493
agent_key: str | None
85-
agent_domain: str | None
94+
_agent_endpoint: str | None
8695

8796
def __init__(
8897
self,
8998
*,
9099
api_key: str | None = None,
91100
inference_key: str | None = None,
92101
agent_key: str | None = None,
93-
agent_domain: str | None = None,
102+
agent_endpoint: str | None = None,
94103
base_url: str | httpx.URL | None = None,
95104
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
96105
max_retries: int = DEFAULT_MAX_RETRIES,
@@ -129,7 +138,7 @@ def __init__(
129138
agent_key = os.environ.get("GRADIENTAI_AGENT_KEY")
130139
self.agent_key = agent_key
131140

132-
self.agent_domain = agent_domain
141+
self._agent_endpoint = agent_endpoint
133142

134143
if base_url is None:
135144
base_url = os.environ.get("GRADIENT_AI_BASE_URL")
@@ -150,6 +159,19 @@ def __init__(
150159

151160
self._default_stream_cls = Stream
152161

162+
@cached_property
163+
def agent_endpoint(self) -> str:
164+
"""
165+
Returns the agent endpoint URL.
166+
"""
167+
if self._agent_endpoint is None:
168+
raise ValueError(
169+
"Agent endpoint is not set. Please provide an agent endpoint when initializing the client."
170+
)
171+
if self._agent_endpoint.startswith("https://"):
172+
return self._agent_endpoint
173+
return "https://" + self._agent_endpoint
174+
153175
@cached_property
154176
def agents(self) -> AgentsResource:
155177
from .resources.agents import AgentsResource
@@ -272,7 +294,9 @@ def default_headers(self) -> dict[str, str | Omit]:
272294

273295
@override
274296
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
275-
if (self.api_key or self.agent_key or self.inference_key) and headers.get("Authorization"):
297+
if (self.api_key or self.agent_key or self.inference_key) and headers.get(
298+
"Authorization"
299+
):
276300
return
277301
if isinstance(custom_headers.get("Authorization"), Omit):
278302
return
@@ -287,7 +311,7 @@ def copy(
287311
api_key: str | None = None,
288312
inference_key: str | None = None,
289313
agent_key: str | None = None,
290-
agent_domain: str | None = None,
314+
agent_endpoint: str | None = None,
291315
base_url: str | httpx.URL | None = None,
292316
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
293317
http_client: httpx.Client | None = None,
@@ -302,10 +326,14 @@ def copy(
302326
Create a new client instance re-using the same options given to the current client with optional overriding.
303327
"""
304328
if default_headers is not None and set_default_headers is not None:
305-
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
329+
raise ValueError(
330+
"The `default_headers` and `set_default_headers` arguments are mutually exclusive"
331+
)
306332

307333
if default_query is not None and set_default_query is not None:
308-
raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")
334+
raise ValueError(
335+
"The `default_query` and `set_default_query` arguments are mutually exclusive"
336+
)
309337

310338
headers = self._custom_headers
311339
if default_headers is not None:
@@ -324,7 +352,7 @@ def copy(
324352
api_key=api_key or self.api_key,
325353
inference_key=inference_key or self.inference_key,
326354
agent_key=agent_key or self.agent_key,
327-
agent_domain=agent_domain or self.agent_domain,
355+
agent_endpoint=agent_endpoint or self._agent_endpoint,
328356
base_url=base_url or self.base_url,
329357
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
330358
http_client=http_client,
@@ -352,10 +380,14 @@ def _make_status_error(
352380
return _exceptions.BadRequestError(err_msg, response=response, body=body)
353381

354382
if response.status_code == 401:
355-
return _exceptions.AuthenticationError(err_msg, response=response, body=body)
383+
return _exceptions.AuthenticationError(
384+
err_msg, response=response, body=body
385+
)
356386

357387
if response.status_code == 403:
358-
return _exceptions.PermissionDeniedError(err_msg, response=response, body=body)
388+
return _exceptions.PermissionDeniedError(
389+
err_msg, response=response, body=body
390+
)
359391

360392
if response.status_code == 404:
361393
return _exceptions.NotFoundError(err_msg, response=response, body=body)
@@ -364,13 +396,17 @@ def _make_status_error(
364396
return _exceptions.ConflictError(err_msg, response=response, body=body)
365397

366398
if response.status_code == 422:
367-
return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body)
399+
return _exceptions.UnprocessableEntityError(
400+
err_msg, response=response, body=body
401+
)
368402

369403
if response.status_code == 429:
370404
return _exceptions.RateLimitError(err_msg, response=response, body=body)
371405

372406
if response.status_code >= 500:
373-
return _exceptions.InternalServerError(err_msg, response=response, body=body)
407+
return _exceptions.InternalServerError(
408+
err_msg, response=response, body=body
409+
)
374410
return APIStatusError(err_msg, response=response, body=body)
375411

376412

@@ -379,15 +415,15 @@ class AsyncGradientAI(AsyncAPIClient):
379415
api_key: str | None
380416
inference_key: str | None
381417
agent_key: str | None
382-
agent_domain: str | None
418+
_agent_endpoint: str | None
383419

384420
def __init__(
385421
self,
386422
*,
387423
api_key: str | None = None,
388424
inference_key: str | None = None,
389425
agent_key: str | None = None,
390-
agent_domain: str | None = None,
426+
agent_endpoint: str | None = None,
391427
base_url: str | httpx.URL | None = None,
392428
timeout: Union[float, Timeout, None, NotGiven] = NOT_GIVEN,
393429
max_retries: int = DEFAULT_MAX_RETRIES,
@@ -426,7 +462,7 @@ def __init__(
426462
agent_key = os.environ.get("GRADIENTAI_AGENT_KEY")
427463
self.agent_key = agent_key
428464

429-
self.agent_domain = agent_domain
465+
self._agent_endpoint = agent_endpoint
430466

431467
if base_url is None:
432468
base_url = os.environ.get("GRADIENT_AI_BASE_URL")
@@ -447,6 +483,19 @@ def __init__(
447483

448484
self._default_stream_cls = AsyncStream
449485

486+
@cached_property
487+
def agent_endpoint(self) -> str:
488+
"""
489+
Returns the agent endpoint URL.
490+
"""
491+
if self._agent_endpoint is None:
492+
raise ValueError(
493+
"Agent endpoint is not set. Please provide an agent endpoint when initializing the client."
494+
)
495+
if self._agent_endpoint.startswith("https://"):
496+
return self._agent_endpoint
497+
return "https://" + self._agent_endpoint
498+
450499
@cached_property
451500
def agents(self) -> AsyncAgentsResource:
452501
from .resources.agents import AsyncAgentsResource
@@ -569,7 +618,9 @@ def default_headers(self) -> dict[str, str | Omit]:
569618

570619
@override
571620
def _validate_headers(self, headers: Headers, custom_headers: Headers) -> None:
572-
if (self.api_key or self.agent_key or self.inference_key) and headers.get("Authorization"):
621+
if (self.api_key or self.agent_key or self.inference_key) and headers.get(
622+
"Authorization"
623+
):
573624
return
574625
if isinstance(custom_headers.get("Authorization"), Omit):
575626
return
@@ -584,7 +635,7 @@ def copy(
584635
api_key: str | None = None,
585636
inference_key: str | None = None,
586637
agent_key: str | None = None,
587-
agent_domain: str | None = None,
638+
agent_endpoint: str | None = None,
588639
base_url: str | httpx.URL | None = None,
589640
timeout: float | Timeout | None | NotGiven = NOT_GIVEN,
590641
http_client: httpx.AsyncClient | None = None,
@@ -599,10 +650,14 @@ def copy(
599650
Create a new client instance re-using the same options given to the current client with optional overriding.
600651
"""
601652
if default_headers is not None and set_default_headers is not None:
602-
raise ValueError("The `default_headers` and `set_default_headers` arguments are mutually exclusive")
653+
raise ValueError(
654+
"The `default_headers` and `set_default_headers` arguments are mutually exclusive"
655+
)
603656

604657
if default_query is not None and set_default_query is not None:
605-
raise ValueError("The `default_query` and `set_default_query` arguments are mutually exclusive")
658+
raise ValueError(
659+
"The `default_query` and `set_default_query` arguments are mutually exclusive"
660+
)
606661

607662
headers = self._custom_headers
608663
if default_headers is not None:
@@ -621,7 +676,7 @@ def copy(
621676
api_key=api_key or self.api_key,
622677
inference_key=inference_key or self.inference_key,
623678
agent_key=agent_key or self.agent_key,
624-
agent_domain=agent_domain or self.agent_domain,
679+
agent_endpoint=agent_endpoint or self._agent_endpoint,
625680
base_url=base_url or self.base_url,
626681
timeout=self.timeout if isinstance(timeout, NotGiven) else timeout,
627682
http_client=http_client,
@@ -649,10 +704,14 @@ def _make_status_error(
649704
return _exceptions.BadRequestError(err_msg, response=response, body=body)
650705

651706
if response.status_code == 401:
652-
return _exceptions.AuthenticationError(err_msg, response=response, body=body)
707+
return _exceptions.AuthenticationError(
708+
err_msg, response=response, body=body
709+
)
653710

654711
if response.status_code == 403:
655-
return _exceptions.PermissionDeniedError(err_msg, response=response, body=body)
712+
return _exceptions.PermissionDeniedError(
713+
err_msg, response=response, body=body
714+
)
656715

657716
if response.status_code == 404:
658717
return _exceptions.NotFoundError(err_msg, response=response, body=body)
@@ -661,13 +720,17 @@ def _make_status_error(
661720
return _exceptions.ConflictError(err_msg, response=response, body=body)
662721

663722
if response.status_code == 422:
664-
return _exceptions.UnprocessableEntityError(err_msg, response=response, body=body)
723+
return _exceptions.UnprocessableEntityError(
724+
err_msg, response=response, body=body
725+
)
665726

666727
if response.status_code == 429:
667728
return _exceptions.RateLimitError(err_msg, response=response, body=body)
668729

669730
if response.status_code >= 500:
670-
return _exceptions.InternalServerError(err_msg, response=response, body=body)
731+
return _exceptions.InternalServerError(
732+
err_msg, response=response, body=body
733+
)
671734
return APIStatusError(err_msg, response=response, body=body)
672735

673736

@@ -793,8 +856,12 @@ def regions(self) -> regions.AsyncRegionsResourceWithRawResponse:
793856
return AsyncRegionsResourceWithRawResponse(self._client.regions)
794857

795858
@cached_property
796-
def knowledge_bases(self) -> knowledge_bases.AsyncKnowledgeBasesResourceWithRawResponse:
797-
from .resources.knowledge_bases import AsyncKnowledgeBasesResourceWithRawResponse
859+
def knowledge_bases(
860+
self,
861+
) -> knowledge_bases.AsyncKnowledgeBasesResourceWithRawResponse:
862+
from .resources.knowledge_bases import (
863+
AsyncKnowledgeBasesResourceWithRawResponse,
864+
)
798865

799866
return AsyncKnowledgeBasesResourceWithRawResponse(self._client.knowledge_bases)
800867

@@ -835,7 +902,9 @@ def images(self) -> images.AsyncImagesResourceWithRawResponse:
835902
return AsyncImagesResourceWithRawResponse(self._client.images)
836903

837904
@cached_property
838-
def load_balancers(self) -> load_balancers.AsyncLoadBalancersResourceWithRawResponse:
905+
def load_balancers(
906+
self,
907+
) -> load_balancers.AsyncLoadBalancersResourceWithRawResponse:
839908
from .resources.load_balancers import AsyncLoadBalancersResourceWithRawResponse
840909

841910
return AsyncLoadBalancersResourceWithRawResponse(self._client.load_balancers)
@@ -890,8 +959,12 @@ def regions(self) -> regions.RegionsResourceWithStreamingResponse:
890959
return RegionsResourceWithStreamingResponse(self._client.regions)
891960

892961
@cached_property
893-
def knowledge_bases(self) -> knowledge_bases.KnowledgeBasesResourceWithStreamingResponse:
894-
from .resources.knowledge_bases import KnowledgeBasesResourceWithStreamingResponse
962+
def knowledge_bases(
963+
self,
964+
) -> knowledge_bases.KnowledgeBasesResourceWithStreamingResponse:
965+
from .resources.knowledge_bases import (
966+
KnowledgeBasesResourceWithStreamingResponse,
967+
)
895968

896969
return KnowledgeBasesResourceWithStreamingResponse(self._client.knowledge_bases)
897970

@@ -932,7 +1005,9 @@ def images(self) -> images.ImagesResourceWithStreamingResponse:
9321005
return ImagesResourceWithStreamingResponse(self._client.images)
9331006

9341007
@cached_property
935-
def load_balancers(self) -> load_balancers.LoadBalancersResourceWithStreamingResponse:
1008+
def load_balancers(
1009+
self,
1010+
) -> load_balancers.LoadBalancersResourceWithStreamingResponse:
9361011
from .resources.load_balancers import LoadBalancersResourceWithStreamingResponse
9371012

9381013
return LoadBalancersResourceWithStreamingResponse(self._client.load_balancers)
@@ -987,10 +1062,16 @@ def regions(self) -> regions.AsyncRegionsResourceWithStreamingResponse:
9871062
return AsyncRegionsResourceWithStreamingResponse(self._client.regions)
9881063

9891064
@cached_property
990-
def knowledge_bases(self) -> knowledge_bases.AsyncKnowledgeBasesResourceWithStreamingResponse:
991-
from .resources.knowledge_bases import AsyncKnowledgeBasesResourceWithStreamingResponse
1065+
def knowledge_bases(
1066+
self,
1067+
) -> knowledge_bases.AsyncKnowledgeBasesResourceWithStreamingResponse:
1068+
from .resources.knowledge_bases import (
1069+
AsyncKnowledgeBasesResourceWithStreamingResponse,
1070+
)
9921071

993-
return AsyncKnowledgeBasesResourceWithStreamingResponse(self._client.knowledge_bases)
1072+
return AsyncKnowledgeBasesResourceWithStreamingResponse(
1073+
self._client.knowledge_bases
1074+
)
9941075

9951076
@cached_property
9961077
def inference(self) -> inference.AsyncInferenceResourceWithStreamingResponse:
@@ -1017,8 +1098,12 @@ def firewalls(self) -> firewalls.AsyncFirewallsResourceWithStreamingResponse:
10171098
return AsyncFirewallsResourceWithStreamingResponse(self._client.firewalls)
10181099

10191100
@cached_property
1020-
def floating_ips(self) -> floating_ips.AsyncFloatingIPsResourceWithStreamingResponse:
1021-
from .resources.floating_ips import AsyncFloatingIPsResourceWithStreamingResponse
1101+
def floating_ips(
1102+
self,
1103+
) -> floating_ips.AsyncFloatingIPsResourceWithStreamingResponse:
1104+
from .resources.floating_ips import (
1105+
AsyncFloatingIPsResourceWithStreamingResponse,
1106+
)
10221107

10231108
return AsyncFloatingIPsResourceWithStreamingResponse(self._client.floating_ips)
10241109

@@ -1029,10 +1114,16 @@ def images(self) -> images.AsyncImagesResourceWithStreamingResponse:
10291114
return AsyncImagesResourceWithStreamingResponse(self._client.images)
10301115

10311116
@cached_property
1032-
def load_balancers(self) -> load_balancers.AsyncLoadBalancersResourceWithStreamingResponse:
1033-
from .resources.load_balancers import AsyncLoadBalancersResourceWithStreamingResponse
1117+
def load_balancers(
1118+
self,
1119+
) -> load_balancers.AsyncLoadBalancersResourceWithStreamingResponse:
1120+
from .resources.load_balancers import (
1121+
AsyncLoadBalancersResourceWithStreamingResponse,
1122+
)
10341123

1035-
return AsyncLoadBalancersResourceWithStreamingResponse(self._client.load_balancers)
1124+
return AsyncLoadBalancersResourceWithStreamingResponse(
1125+
self._client.load_balancers
1126+
)
10361127

10371128
@cached_property
10381129
def sizes(self) -> sizes.AsyncSizesResourceWithStreamingResponse:

0 commit comments

Comments
 (0)