24
24
import java .util .HashMap ;
25
25
import java .util .List ;
26
26
import java .util .Map ;
27
+ import java .util .stream .Collectors ;
27
28
import javax .annotation .Nullable ;
28
29
import software .amazon .awssdk .core .SdkRequest ;
29
30
import software .amazon .awssdk .core .SdkResponse ;
30
31
import software .amazon .awssdk .core .async .SdkPublisher ;
31
32
import software .amazon .awssdk .core .document .Document ;
32
33
import software .amazon .awssdk .protocols .json .SdkJsonGenerator ;
34
+ import software .amazon .awssdk .protocols .jsoncore .JsonNode ;
35
+ import software .amazon .awssdk .protocols .jsoncore .JsonNodeParser ;
33
36
import software .amazon .awssdk .services .bedrockruntime .BedrockRuntimeAsyncClient ;
34
37
import software .amazon .awssdk .services .bedrockruntime .model .ContentBlock ;
38
+ import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockDelta ;
39
+ import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockDeltaEvent ;
40
+ import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockStartEvent ;
41
+ import software .amazon .awssdk .services .bedrockruntime .model .ContentBlockStopEvent ;
35
42
import software .amazon .awssdk .services .bedrockruntime .model .ConverseRequest ;
36
43
import software .amazon .awssdk .services .bedrockruntime .model .ConverseResponse ;
37
44
import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamMetadataEvent ;
41
48
import software .amazon .awssdk .services .bedrockruntime .model .ConverseStreamResponseHandler ;
42
49
import software .amazon .awssdk .services .bedrockruntime .model .InferenceConfiguration ;
43
50
import software .amazon .awssdk .services .bedrockruntime .model .Message ;
51
+ import software .amazon .awssdk .services .bedrockruntime .model .MessageStartEvent ;
44
52
import software .amazon .awssdk .services .bedrockruntime .model .MessageStopEvent ;
45
53
import software .amazon .awssdk .services .bedrockruntime .model .StopReason ;
46
54
import software .amazon .awssdk .services .bedrockruntime .model .TokenUsage ;
47
55
import software .amazon .awssdk .services .bedrockruntime .model .ToolResultContentBlock ;
48
56
import software .amazon .awssdk .services .bedrockruntime .model .ToolUseBlock ;
57
+ import software .amazon .awssdk .services .bedrockruntime .model .ToolUseBlockStart ;
49
58
import software .amazon .awssdk .thirdparty .jackson .core .JsonFactory ;
50
59
51
60
/**
@@ -59,6 +68,8 @@ private BedrockRuntimeImpl() {}
59
68
private static final AttributeKey <String > GEN_AI_SYSTEM = stringKey ("gen_ai.system" );
60
69
61
70
private static final JsonFactory JSON_FACTORY = new JsonFactory ();
71
+ private static final JsonNodeParser JSON_PARSER = JsonNode .parser ();
72
+ private static final DocumentUnmarshaller DOCUMENT_UNMARSHALLER = new DocumentUnmarshaller ();
62
73
63
74
static boolean isBedrockRuntimeRequest (SdkRequest request ) {
64
75
if (request instanceof ConverseRequest ) {
@@ -202,35 +213,54 @@ static Long getUsageOutputTokens(Response response) {
202
213
static void recordRequestEvents (
203
214
Context otelContext , Logger eventLogger , SdkRequest request , boolean captureMessageContent ) {
204
215
if (request instanceof ConverseRequest ) {
205
- for (Message message : ((ConverseRequest ) request ).messages ()) {
206
- long numToolResults =
207
- message .content ().stream ().filter (block -> block .toolResult () != null ).count ();
208
- if (numToolResults > 0 ) {
209
- // Tool results are different from others, emitting multiple events for a single message,
210
- // so treat them separately.
211
- emitToolResultEvents (otelContext , eventLogger , message , captureMessageContent );
212
- if (numToolResults == message .content ().size ()) {
213
- continue ;
214
- }
215
- // There are content blocks besides tool results in the same message. While models
216
- // generally don't expect such usage, the SDK allows it so go ahead and generate a normal
217
- // message too.
218
- }
219
- LogRecordBuilder event = newEvent (otelContext , eventLogger );
220
- switch (message .role ()) {
221
- case ASSISTANT :
222
- event .setAttribute (EVENT_NAME , "gen_ai.assistant.message" );
223
- break ;
224
- case USER :
225
- event .setAttribute (EVENT_NAME , "gen_ai.user.message" );
226
- break ;
227
- default :
228
- // unknown role, shouldn't happen in practice
229
- continue ;
216
+ recordRequestMessageEvents (
217
+ otelContext , eventLogger , ((ConverseRequest ) request ).messages (), captureMessageContent );
218
+ }
219
+ if (request instanceof ConverseStreamRequest ) {
220
+ recordRequestMessageEvents (
221
+ otelContext ,
222
+ eventLogger ,
223
+ ((ConverseStreamRequest ) request ).messages (),
224
+ captureMessageContent );
225
+
226
+ // Good a time as any to store the context for a streaming request.
227
+ TracingConverseStreamResponseHandler .fromContext (otelContext ).setOtelContext (otelContext );
228
+ }
229
+ }
230
+
231
+ private static void recordRequestMessageEvents (
232
+ Context otelContext ,
233
+ Logger eventLogger ,
234
+ List <Message > messages ,
235
+ boolean captureMessageContent ) {
236
+ for (Message message : messages ) {
237
+ long numToolResults =
238
+ message .content ().stream ().filter (block -> block .toolResult () != null ).count ();
239
+ if (numToolResults > 0 ) {
240
+ // Tool results are different from others, emitting multiple events for a single message,
241
+ // so treat them separately.
242
+ emitToolResultEvents (otelContext , eventLogger , message , captureMessageContent );
243
+ if (numToolResults == message .content ().size ()) {
244
+ continue ;
230
245
}
231
- // Requests don't have index or stop reason.
232
- event .setBody (convertMessage (message , -1 , null , captureMessageContent )).emit ();
246
+ // There are content blocks besides tool results in the same message. While models
247
+ // generally don't expect such usage, the SDK allows it so go ahead and generate a normal
248
+ // message too.
233
249
}
250
+ LogRecordBuilder event = newEvent (otelContext , eventLogger );
251
+ switch (message .role ()) {
252
+ case ASSISTANT :
253
+ event .setAttribute (EVENT_NAME , "gen_ai.assistant.message" );
254
+ break ;
255
+ case USER :
256
+ event .setAttribute (EVENT_NAME , "gen_ai.user.message" );
257
+ break ;
258
+ default :
259
+ // unknown role, shouldn't happen in practice
260
+ continue ;
261
+ }
262
+ // Requests don't have index or stop reason.
263
+ event .setBody (convertMessage (message , -1 , null , captureMessageContent )).emit ();
234
264
}
235
265
}
236
266
@@ -248,7 +278,7 @@ static void recordResponseEvents(
248
278
convertMessage (
249
279
converseResponse .output ().message (),
250
280
0 ,
251
- converseResponse .stopReason (),
281
+ converseResponse .stopReasonAsString (),
252
282
captureMessageContent ))
253
283
.emit ();
254
284
}
@@ -270,7 +300,8 @@ private static Double floatToDouble(Float value) {
270
300
return Double .valueOf (value );
271
301
}
272
302
273
- public static BedrockRuntimeAsyncClient wrap (BedrockRuntimeAsyncClient asyncClient ) {
303
+ public static BedrockRuntimeAsyncClient wrap (
304
+ BedrockRuntimeAsyncClient asyncClient , Logger eventLogger , boolean captureMessageContent ) {
274
305
// proxy BedrockRuntimeAsyncClient so we can wrap the subscriber to converseStream to capture
275
306
// events.
276
307
return (BedrockRuntimeAsyncClient )
@@ -283,7 +314,9 @@ public static BedrockRuntimeAsyncClient wrap(BedrockRuntimeAsyncClient asyncClie
283
314
&& args [1 ] instanceof ConverseStreamResponseHandler ) {
284
315
TracingConverseStreamResponseHandler wrapped =
285
316
new TracingConverseStreamResponseHandler (
286
- (ConverseStreamResponseHandler ) args [1 ]);
317
+ (ConverseStreamResponseHandler ) args [1 ],
318
+ eventLogger ,
319
+ captureMessageContent );
287
320
args [1 ] = wrapped ;
288
321
try (Scope ignored = wrapped .makeCurrent ()) {
289
322
return invokeProxyMethod (method , asyncClient , args );
@@ -318,12 +351,29 @@ public static TracingConverseStreamResponseHandler fromContext(Context context)
318
351
ContextKey .named ("bedrock-runtime-converse-stream-response-handler" );
319
352
320
353
private final ConverseStreamResponseHandler delegate ;
354
+ private final Logger eventLogger ;
355
+ private final boolean captureMessageContent ;
356
+
357
+ private StringBuilder currentText ;
358
+
359
+ // The response handler is created and stored into context before the span, so we need to
360
+ // also pass the later context in for recording events. While subscribers are called from a
361
+ // single thread, it is not clear if that is guaranteed to be the same as the execution
362
+ // interceptor so we use volatile.
363
+ private volatile Context otelContext ;
364
+
365
+ private List <ToolUseBlock > tools ;
366
+ private ToolUseBlock .Builder currentTool ;
367
+ private StringBuilder currentToolArgs ;
321
368
322
369
List <String > stopReasons ;
323
370
TokenUsage usage ;
324
371
325
/**
 * Creates a handler that forwards to {@code delegate} and keeps what it needs to emit events.
 *
 * @param delegate the handler supplied by the caller of {@code converseStream}
 * @param eventLogger logger used when emitting events for the streamed response
 * @param captureMessageContent whether streamed message text may be recorded
 */
TracingConverseStreamResponseHandler(
    ConverseStreamResponseHandler delegate, Logger eventLogger, boolean captureMessageContent) {
  this.captureMessageContent = captureMessageContent;
  this.eventLogger = eventLogger;
  this.delegate = delegate;
}
328
378
329
379
@ Override
@@ -336,19 +386,66 @@ public void onEventStream(SdkPublisher<ConverseStreamOutput> sdkPublisher) {
336
386
delegate .onEventStream (
337
387
sdkPublisher .map (
338
388
event -> {
339
- if (event instanceof MessageStopEvent ) {
340
- if (stopReasons == null ) {
341
- stopReasons = new ArrayList <>();
342
- }
343
- stopReasons .add (((MessageStopEvent ) event ).stopReasonAsString ());
344
- }
345
- if (event instanceof ConverseStreamMetadataEvent ) {
346
- usage = ((ConverseStreamMetadataEvent ) event ).usage ();
347
- }
389
+ handleEvent (event );
348
390
return event ;
349
391
}));
350
392
}
351
393
394
+ private void handleEvent (ConverseStreamOutput event ) {
395
+ if (captureMessageContent && event instanceof MessageStartEvent ) {
396
+ if (currentText == null ) {
397
+ currentText = new StringBuilder ();
398
+ }
399
+ currentText .setLength (0 );
400
+ }
401
+ if (event instanceof ContentBlockStartEvent ) {
402
+ ToolUseBlockStart toolUse = ((ContentBlockStartEvent ) event ).start ().toolUse ();
403
+ if (toolUse != null ) {
404
+ if (currentToolArgs == null ) {
405
+ currentToolArgs = new StringBuilder ();
406
+ }
407
+ currentToolArgs .setLength (0 );
408
+ currentTool = ToolUseBlock .builder ().name (toolUse .name ()).toolUseId (toolUse .toolUseId ());
409
+ }
410
+ }
411
+ if (event instanceof ContentBlockDeltaEvent ) {
412
+ ContentBlockDelta delta = ((ContentBlockDeltaEvent ) event ).delta ();
413
+ if (captureMessageContent && delta .text () != null ) {
414
+ currentText .append (delta .text ());
415
+ }
416
+ if (delta .toolUse () != null ) {
417
+ currentToolArgs .append (delta .toolUse ().input ());
418
+ }
419
+ }
420
+ if (event instanceof ContentBlockStopEvent ) {
421
+ if (currentTool != null ) {
422
+ if (tools == null ) {
423
+ tools = new ArrayList <>();
424
+ }
425
+ if (currentToolArgs != null ) {
426
+ Document args = deserializeDocument (currentToolArgs .toString ());
427
+ currentTool .input (args );
428
+ }
429
+ tools .add (currentTool .build ());
430
+ currentTool = null ;
431
+ }
432
+ }
433
+ if (event instanceof MessageStopEvent ) {
434
+ if (stopReasons == null ) {
435
+ stopReasons = new ArrayList <>();
436
+ }
437
+ String stopReason = ((MessageStopEvent ) event ).stopReasonAsString ();
438
+ stopReasons .add (stopReason );
439
+ newEvent (otelContext , eventLogger )
440
+ .setAttribute (EVENT_NAME , "gen_ai.choice" )
441
+ .setBody (convertMessageData (currentText , tools , 0 , stopReason , captureMessageContent ))
442
+ .emit ();
443
+ }
444
+ if (event instanceof ConverseStreamMetadataEvent ) {
445
+ usage = ((ConverseStreamMetadataEvent ) event ).usage ();
446
+ }
447
+ }
448
+
352
449
@ Override
353
450
public void exceptionOccurred (Throwable throwable ) {
354
451
delegate .exceptionOccurred (throwable );
@@ -363,6 +460,10 @@ public void complete() {
363
460
public Context storeInContext (Context context ) {
364
461
return context .with (KEY , this );
365
462
}
463
+
464
+ void setOtelContext (Context otelContext ) {
465
+ this .otelContext = otelContext ;
466
+ }
366
467
}
367
468
368
469
private static LogRecordBuilder newEvent (Context otelContext , Logger eventLogger ) {
@@ -401,9 +502,9 @@ private static void emitToolResultEvents(
401
502
}
402
503
403
504
private static Value <?> convertMessage (
404
- Message message , int index , @ Nullable StopReason stopReason , boolean captureMessageContent ) {
505
+ Message message , int index , @ Nullable String stopReason , boolean captureMessageContent ) {
405
506
StringBuilder text = null ;
406
- List <Value <?> > toolCalls = null ;
507
+ List <ToolUseBlock > toolCalls = null ;
407
508
for (ContentBlock content : message .content ()) {
408
509
if (captureMessageContent && content .text () != null ) {
409
510
if (text == null ) {
@@ -415,15 +516,29 @@ private static Value<?> convertMessage(
415
516
if (toolCalls == null ) {
416
517
toolCalls = new ArrayList <>();
417
518
}
418
- toolCalls .add (convertToolCall ( content .toolUse (), captureMessageContent ));
519
+ toolCalls .add (content .toolUse ());
419
520
}
420
521
}
522
+
523
+ return convertMessageData (text , toolCalls , index , stopReason , captureMessageContent );
524
+ }
525
+
526
+ private static Value <?> convertMessageData (
527
+ @ Nullable StringBuilder text ,
528
+ List <ToolUseBlock > toolCalls ,
529
+ int index ,
530
+ @ Nullable String stopReason ,
531
+ boolean captureMessageContent ) {
421
532
Map <String , Value <?>> body = new HashMap <>();
422
533
if (text != null ) {
423
534
body .put ("content" , Value .of (text .toString ()));
424
535
}
425
536
if (toolCalls != null ) {
426
- body .put ("toolCalls" , Value .of (toolCalls ));
537
+ List <Value <?>> toolCallValues =
538
+ toolCalls .stream ()
539
+ .map (tool -> convertToolCall (tool , captureMessageContent ))
540
+ .collect (Collectors .toList ());
541
+ body .put ("toolCalls" , Value .of (toolCallValues ));
427
542
}
428
543
if (stopReason != null ) {
429
544
body .put ("finish_reason" , Value .of (stopReason .toString ()));
@@ -451,4 +566,9 @@ private static String serializeDocument(Document document) {
451
566
document .accept (marshaller );
452
567
return new String (generator .getBytes (), StandardCharsets .UTF_8 );
453
568
}
569
+
570
+ private static Document deserializeDocument (String json ) {
571
+ JsonNode node = JSON_PARSER .parse (json );
572
+ return node .visit (DOCUMENT_UNMARSHALLER );
573
+ }
454
574
}
0 commit comments