14
14
15
15
from __future__ import annotations
16
16
17
+ from contextlib import contextmanager
17
18
from typing import (
18
19
TYPE_CHECKING ,
19
20
Any ,
21
+ Awaitable ,
20
22
Callable ,
21
23
MutableSequence ,
22
24
)
@@ -87,17 +89,17 @@ def _extract_params(
87
89
)
88
90
89
91
90
- def generate_content_create (
91
- tracer : Tracer , event_logger : EventLogger , capture_content : bool
92
- ):
93
- """Wrap the `generate_content` method of the `GenerativeModel` class to trace it."""
92
class MethodWrappers:
    """Shared instrumentation state for the wrapped ``generate_content`` methods.

    Holds the tracer, the event logger, and the content-capture flag once, so
    each wrapped client call reuses the same instrumentation objects.
    """

    def __init__(
        self, tracer: Tracer, event_logger: EventLogger, capture_content: bool
    ) -> None:
        # Bind all instrumentation state in one shot.
        self.tracer, self.event_logger, self.capture_content = (
            tracer,
            event_logger,
            capture_content,
        )
94
99
95
- def traced_method (
96
- wrapped : Callable [
97
- ...,
98
- prediction_service .GenerateContentResponse
99
- | prediction_service_v1beta1 .GenerateContentResponse ,
100
- ],
100
+ @contextmanager
101
+ def _with_instrumentation (
102
+ self ,
101
103
instance : client .PredictionServiceClient
102
104
| client_v1beta1 .PredictionServiceClient ,
103
105
args : Any ,
@@ -111,32 +113,82 @@ def traced_method(
111
113
}
112
114
113
115
span_name = get_span_name (span_attributes )
114
- with tracer .start_as_current_span (
116
+
117
+ with self .tracer .start_as_current_span (
115
118
name = span_name ,
116
119
kind = SpanKind .CLIENT ,
117
120
attributes = span_attributes ,
118
121
) as span :
119
122
for event in request_to_events (
120
- params = params , capture_content = capture_content
123
+ params = params , capture_content = self . capture_content
121
124
):
122
- event_logger .emit (event )
125
+ self . event_logger .emit (event )
123
126
124
127
# TODO: set error.type attribute
125
128
# https://github.com/open-telemetry/semantic-conventions/blob/main/docs/gen-ai/gen-ai-spans.md
126
- response = wrapped (* args , ** kwargs )
127
- # TODO: handle streaming
128
- # if is_streaming(kwargs):
129
- # return StreamWrapper(
130
- # result, span, event_logger, capture_content
131
- # )
132
-
133
- if span .is_recording ():
134
- span .set_attributes (get_genai_response_attributes (response ))
135
- for event in response_to_events (
136
- response = response , capture_content = capture_content
137
- ):
138
- event_logger .emit (event )
139
129
130
+ def handle_response (
131
+ response : prediction_service .GenerateContentResponse
132
+ | prediction_service_v1beta1 .GenerateContentResponse ,
133
+ ) -> None :
134
+ if span .is_recording ():
135
+ # When streaming, this is called multiple times so attributes would be
136
+ # overwritten. In practice, it looks the API only returns the interesting
137
+ # attributes on the last streamed response. However, I couldn't find
138
+ # documentation for this and setting attributes shouldn't be too expensive.
139
+ span .set_attributes (
140
+ get_genai_response_attributes (response )
141
+ )
142
+
143
+ for event in response_to_events (
144
+ response = response , capture_content = self .capture_content
145
+ ):
146
+ self .event_logger .emit (event )
147
+
148
+ yield handle_response
149
+
150
+ def generate_content (
151
+ self ,
152
+ wrapped : Callable [
153
+ ...,
154
+ prediction_service .GenerateContentResponse
155
+ | prediction_service_v1beta1 .GenerateContentResponse ,
156
+ ],
157
+ instance : client .PredictionServiceClient
158
+ | client_v1beta1 .PredictionServiceClient ,
159
+ args : Any ,
160
+ kwargs : Any ,
161
+ ) -> (
162
+ prediction_service .GenerateContentResponse
163
+ | prediction_service_v1beta1 .GenerateContentResponse
164
+ ):
165
+ with self ._with_instrumentation (
166
+ instance , args , kwargs
167
+ ) as handle_response :
168
+ response = wrapped (* args , ** kwargs )
169
+ handle_response (response )
140
170
return response
141
171
142
- return traced_method
172
+ async def agenerate_content (
173
+ self ,
174
+ wrapped : Callable [
175
+ ...,
176
+ Awaitable [
177
+ prediction_service .GenerateContentResponse
178
+ | prediction_service_v1beta1 .GenerateContentResponse
179
+ ],
180
+ ],
181
+ instance : client .PredictionServiceClient
182
+ | client_v1beta1 .PredictionServiceClient ,
183
+ args : Any ,
184
+ kwargs : Any ,
185
+ ) -> (
186
+ prediction_service .GenerateContentResponse
187
+ | prediction_service_v1beta1 .GenerateContentResponse
188
+ ):
189
+ with self ._with_instrumentation (
190
+ instance , args , kwargs
191
+ ) as handle_response :
192
+ response = await wrapped (* args , ** kwargs )
193
+ handle_response (response )
194
+ return response
0 commit comments