diff --git a/docs/my-website/docs/observability/braintrust.md b/docs/my-website/docs/observability/braintrust.md index e6b4fe769bc4..645ce074ca52 100644 --- a/docs/my-website/docs/observability/braintrust.md +++ b/docs/my-website/docs/observability/braintrust.md @@ -75,6 +75,12 @@ It is recommended that you include the `project_id` or `project_name` to ensure You can customize the span name in Braintrust logging by passing `span_name` in the metadata. By default, the span name is set to "Chat Completion". +### Custom Span Attributes + +You can customize the span id, root span name and span parents in Braintrust logging by passing `span_id`, `root_span_id` and `span_parents` in the metadata. +`span_parents` should be a string containing a list of span ids, joined by , + + diff --git a/litellm/integrations/braintrust_logging.py b/litellm/integrations/braintrust_logging.py index 5bc6afb6dbc1..364fa3f5defd 100644 --- a/litellm/integrations/braintrust_logging.py +++ b/litellm/integrations/braintrust_logging.py @@ -206,6 +206,20 @@ def log_success_event( # noqa: PLR0915 # Allow metadata override for span name span_name = dynamic_metadata.get("span_name", "Chat Completion") + + # Span parents is a special case + span_parents = dynamic_metadata.get("span_parents") + + # Convert comma-separated string to list if present + if span_parents: + span_parents = [s.strip() for s in span_parents.split(",") if s.strip()] + + # Add optional span attributes only if present + span_attributes = { + "span_id": dynamic_metadata.get("span_id"), + "root_span_id": dynamic_metadata.get("root_span_id"), + "span_parents": span_parents, + } request_data = { "id": litellm_call_id, @@ -214,6 +228,12 @@ def log_success_event( # noqa: PLR0915 "tags": tags, "span_attributes": {"name": span_name, "type": "llm"}, } + + # Only add those that are not None (or falsy) + for key, value in span_attributes.items(): + if value: + request_data[key] = value + if choices is not None: request_data["output"] = [choice.dict() for choice in choices] else: diff --git a/tests/test_litellm/integrations/test_braintrust_span_name.py b/tests/test_litellm/integrations/test_braintrust_span_name.py index 30381e997839..7050a6d355f1 100644 --- a/tests/test_litellm/integrations/test_braintrust_span_name.py +++ b/tests/test_litellm/integrations/test_braintrust_span_name.py @@ -224,6 +224,76 @@ async def test_async_custom_span_name(self, mock_get_http_handler): json_data["events"][0]["span_attributes"]["name"], "Async Custom Operation" ) + @patch('litellm.integrations.braintrust_logging.HTTPHandler') + def test_span_attributes_with_multiple_metadata_fields(self, MockHTTPHandler): + """Test that span_name works correctly alongside other metadata fields.""" + # Mock HTTP response + mock_response = Mock() + mock_response.json.return_value = {"id": "test-project-id"} + mock_http_handler = Mock() + mock_http_handler.post.return_value = mock_response + MockHTTPHandler.return_value = mock_http_handler + + # Setup + logger = BraintrustLogger(api_key="test-key") + logger.default_project_id = "test-project-id" + + # Create a mock response object + message_mock = Mock() + message_mock.json = Mock(return_value={"content": "test"}) + + choice_mock = Mock() + choice_mock.message = message_mock + choice_mock.dict = Mock(return_value={"message": {"content": "test"}}) + choice_mock.__getitem__ = Mock(return_value=message_mock) + + response_obj = Mock(spec=litellm.ModelResponse) + response_obj.choices = [choice_mock] + response_obj.__getitem__ = Mock(return_value=[choice_mock]) + response_obj.usage = litellm.Usage( + prompt_tokens=10, + completion_tokens=20, + total_tokens=30 + ) + + kwargs = { + "litellm_call_id": "test-call-id", + "messages": [{"role": "user", "content": "test"}], + "litellm_params": { + "metadata": { + "span_name": "Multi Metadata Test", + "span_id": "span_id", + "root_span_id": "root_span_id", + "span_parents": "span_parent1,span_parent2", + "project_id": "custom-project", + "user_id": "user123", + "session_id": "session456" + } + }, + "model": "gpt-3.5-turbo", + "response_cost": 0.001 + } + + # Execute + logger.log_success_event(kwargs, response_obj, datetime.now(), datetime.now()) + + # Verify + call_args = mock_http_handler.post.call_args + self.assertIsNotNone(call_args) + json_data = call_args.kwargs['json'] + + # Check span name + self.assertEqual(json_data['events'][0]['span_attributes']['name'], 'Multi Metadata Test') + self.assertEqual(json_data['events'][0]['span_id'], 'span_id') + self.assertEqual(json_data['events'][0]['root_span_id'], 'root_span_id') + self.assertEqual(json_data['events'][0]['span_parents'][0], 'span_parent1') + self.assertEqual(json_data['events'][0]['span_parents'][1], 'span_parent2') + + # Check that other metadata is preserved + event_metadata = json_data['events'][0]['metadata'] + self.assertEqual(event_metadata['user_id'], 'user123') + self.assertEqual(event_metadata['session_id'], 'session456') + if __name__ == "__main__": unittest.main()