diff --git a/fastdeploy/collect_env.py b/fastdeploy/collect_env.py
index ede1f3753e5..4a1ae115a13 100644
--- a/fastdeploy/collect_env.py
+++ b/fastdeploy/collect_env.py
@@ -561,7 +561,7 @@ def get_env_info():

     if PADDLE_AVAILABLE:
         paddle_version_str = paddle.__version__
-        paddle_cuda_available_str = str(torch.cuda.is_available())
+        paddle_cuda_available_str = str(paddle.device.is_compiled_with_cuda())
         paddle_cuda_version_str = str(paddle.version.cuda())
     else:
         version_str = paddle_cuda_available_str = cuda_version_str = "N/A"
diff --git a/fastdeploy/engine/common_engine.py b/fastdeploy/engine/common_engine.py
index bfd5e83d411..830624462b5 100644
--- a/fastdeploy/engine/common_engine.py
+++ b/fastdeploy/engine/common_engine.py
@@ -1060,6 +1060,7 @@ def _exit_sub_services(self):
         """
         exit sub services
         """
+        llm_logger.info("Exit sub services.....")
         self.running = False
         if hasattr(self, "engine_worker_queue_server") and self.engine_worker_queue_server is not None:
             self.engine_worker_queue_server.cleanup()
diff --git a/fastdeploy/entrypoints/openai/protocol.py b/fastdeploy/entrypoints/openai/protocol.py
index 7329af8a4f0..bc44dcad3be 100644
--- a/fastdeploy/entrypoints/openai/protocol.py
+++ b/fastdeploy/entrypoints/openai/protocol.py
@@ -66,6 +66,7 @@ class CompletionTokenUsageInfo(BaseModel):
     """

     reasoning_tokens: Optional[int] = None
+    image_tokens: Optional[int] = None


 class PromptTokenUsageInfo(BaseModel):
@@ -74,6 +75,8 @@ class PromptTokenUsageInfo(BaseModel):
     """

     cached_tokens: Optional[int] = None
+    image_tokens: Optional[int] = None
+    video_tokens: Optional[int] = None


 class UsageInfo(BaseModel):
diff --git a/fastdeploy/entrypoints/openai/response_processors.py b/fastdeploy/entrypoints/openai/response_processors.py
index b340133d6b7..95a5e3ec404 100644
--- a/fastdeploy/entrypoints/openai/response_processors.py
+++ b/fastdeploy/entrypoints/openai/response_processors.py
@@ -16,6 +16,7 @@

 from typing import Any, List, Optional

+from fastdeploy.entrypoints.openai.usage_calculator import count_tokens
 from fastdeploy.input.tokenzier_client import AsyncTokenizerClient, ImageDecodeRequest
 from fastdeploy.utils import api_server_logger

@@ -104,6 +105,7 @@ async def process_response_chat(self, request_outputs, stream, enable_thinking,
                     image_output = self._end_image_code_request_output
                     image_output["outputs"]["multipart"] = [image]
                     image_output["outputs"]["token_ids"] = all_tokens
+                    image_output["outputs"]["num_image_tokens"] = count_tokens(all_tokens)
                     yield image_output

                 self.data_processor.process_response_dict(
@@ -124,6 +126,7 @@ async def process_response_chat(self, request_outputs, stream, enable_thinking,
             token_ids = request_output["outputs"]["token_ids"]
             if token_ids[-1] == self.eos_token_id:
                 multipart = []
+                num_image_tokens = 0
                 for part in self._multipart_buffer:
                     if part["decode_type"] == 0:
                         self.data_processor.process_response_dict(
@@ -139,6 +142,7 @@ async def process_response_chat(self, request_outputs, stream, enable_thinking,
                         if self.decoder_client:
                             req_id = part["request_output"]["request_id"]
                             all_tokens = part["request_output"]["outputs"]["token_ids"]
+                            num_image_tokens += count_tokens(all_tokens)

                             image_ret = await self.decoder_client.decode_image(
                                 request=ImageDecodeRequest(req_id=req_id, data=all_tokens)
@@ -150,4 +154,5 @@ async def process_response_chat(self, request_outputs, stream, enable_thinking,

             lasrt_request_output = self._multipart_buffer[-1]["request_output"]
             lasrt_request_output["outputs"]["multipart"] = multipart
+            lasrt_request_output["outputs"]["num_image_tokens"] = num_image_tokens
             yield lasrt_request_output
diff --git a/fastdeploy/entrypoints/openai/serving_chat.py b/fastdeploy/entrypoints/openai/serving_chat.py
index 6e184739eca..802c3e8dc10 100644
--- a/fastdeploy/entrypoints/openai/serving_chat.py
+++ b/fastdeploy/entrypoints/openai/serving_chat.py
@@ -189,6 +189,8 @@ async def chat_completion_stream_generator(
         previous_num_tokens = [0] * num_choices
         reasoning_num_tokens = [0] * num_choices
         num_prompt_tokens = 0
+        num_cached_tokens = 0
+        num_image_tokens = [0] * num_choices
         tool_called = [False] * num_choices
         max_streaming_response_tokens = (
             request.max_streaming_response_tokens
@@ -321,6 +323,9 @@ async def chat_completion_stream_generator(
                         output_top_logprobs = output["top_logprobs"]
                         output_draft_top_logprobs = output["draft_top_logprobs"]
                         previous_num_tokens[idx] += len(output["token_ids"])
+                        if output.get("num_image_tokens"):
+                            previous_num_tokens[idx] += output.get("num_image_tokens")
+                            num_image_tokens[idx] += output.get("num_image_tokens")
                         reasoning_num_tokens[idx] += output.get("reasoning_token_num", 0)
                         logprobs_res: Optional[LogProbs] = None
                         draft_logprobs_res: Optional[LogProbs] = None
@@ -389,8 +394,10 @@ async def chat_completion_stream_generator(
                             prompt_tokens=num_prompt_tokens,
                             completion_tokens=previous_num_tokens[idx],
                             total_tokens=num_prompt_tokens + previous_num_tokens[idx],
+                            prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens),
                             completion_tokens_details=CompletionTokenUsageInfo(
-                                reasoning_tokens=reasoning_num_tokens[idx]
+                                reasoning_tokens=reasoning_num_tokens[idx],
+                                image_tokens=num_image_tokens[idx],
                             ),
                         )
                         choices.append(choice)
@@ -409,7 +416,10 @@ async def chat_completion_stream_generator(
                     prompt_tokens=num_prompt_tokens,
                     completion_tokens=completion_tokens,
                     total_tokens=num_prompt_tokens + completion_tokens,
-                    completion_tokens_details=CompletionTokenUsageInfo(reasoning_tokens=reasoning_tokens),
+                    prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cached_tokens),
+                    completion_tokens_details=CompletionTokenUsageInfo(
+                        image_tokens=sum(num_image_tokens), reasoning_tokens=reasoning_tokens
+                    ),
                 )
                 chunk = ChatCompletionStreamResponse(
                     id=request_id,
@@ -466,6 +476,7 @@ async def chat_completion_full_generator(
         draft_logprob_contents = [[] for _ in range(num_choices)]
         completion_token_ids = [[] for _ in range(num_choices)]
         num_cached_tokens = [0] * num_choices
+        num_image_tokens = [0] * num_choices
         response_processor = ChatResponseProcessor(
             data_processor=self.engine_client.data_processor,
             enable_mm_output=self.enable_mm_output,
@@ -531,6 +542,9 @@ async def chat_completion_full_generator(
                     if data["finished"]:
                         num_choices -= 1
                         reasoning_num_tokens[idx] = data["outputs"].get("reasoning_token_num", 0)
+                        if data["outputs"].get("image_token_num"):
+                            previous_num_tokens[idx] += data["outputs"].get("image_token_num")
+                            num_image_tokens[idx] = data["outputs"].get("image_token_num")
                         choice = await self._create_chat_completion_choice(
                             output=output,
                             index=idx,
@@ -540,6 +554,7 @@ async def chat_completion_full_generator(
                             prompt_tokens=prompt_tokens,
                             completion_token_ids=completion_token_ids[idx],
                             num_cached_tokens=num_cached_tokens,
+                            num_image_tokens=num_image_tokens,
                             logprob_contents=logprob_contents,
                             response_processor=response_processor,
                         )
@@ -557,7 +572,9 @@ async def chat_completion_full_generator(
             completion_tokens=num_generated_tokens,
             total_tokens=num_prompt_tokens + num_generated_tokens,
             prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=sum(num_cached_tokens)),
-            completion_tokens_details=CompletionTokenUsageInfo(reasoning_tokens=num_reasoning_tokens),
+            completion_tokens_details=CompletionTokenUsageInfo(
+                reasoning_tokens=num_reasoning_tokens, image_tokens=sum(num_image_tokens)
+            ),
         )
         choices = sorted(choices, key=lambda x: x.index)
         res = ChatCompletionResponse(
@@ -580,6 +597,7 @@ async def _create_chat_completion_choice(
         prompt_tokens: str,
         completion_token_ids: list,
         num_cached_tokens: list,
+        num_image_tokens: list,
         logprob_contents: list,
         response_processor: ChatResponseProcessor,
     ) -> ChatCompletionResponseChoice:
@@ -609,6 +627,7 @@ async def _create_chat_completion_choice(
         has_no_token_limit = request.max_tokens is None and request.max_completion_tokens is None
         max_tokens = request.max_completion_tokens or request.max_tokens
         num_cached_tokens[index] = output.get("num_cached_tokens", 0)
+        num_image_tokens[index] = output.get("num_image_tokens", 0)

         finish_reason = "stop"
         if has_no_token_limit or previous_num_tokens != max_tokens:
diff --git a/fastdeploy/entrypoints/openai/serving_completion.py b/fastdeploy/entrypoints/openai/serving_completion.py
index 5ce12e87b4b..c27375305ed 100644
--- a/fastdeploy/entrypoints/openai/serving_completion.py
+++ b/fastdeploy/entrypoints/openai/serving_completion.py
@@ -33,6 +33,7 @@
     CompletionTokenUsageInfo,
     ErrorInfo,
     ErrorResponse,
+    PromptTokenUsageInfo,
     UsageInfo,
 )
 from fastdeploy.utils import (
@@ -370,6 +371,8 @@ async def completion_stream_generator(
                 req_id = f"{request_id}_{i}"
                 dealer.write([b"", req_id.encode("utf-8")])  # 发送多路请求
             output_tokens = [0] * num_choices
+            num_cache_tokens = [0] * num_choices
+            num_image_tokens = [0] * num_choices
             inference_start_time = [0] * num_choices
             reasoning_tokens = [0] * num_choices
             first_iteration = [True] * num_choices
@@ -459,7 +462,11 @@ async def completion_stream_generator(
                             draft_logprobs_res = self._create_completion_logprobs(
                                 output_draft_top_logprobs, request.logprobs, 0
                             )
-                        output_tokens[idx] += 1
+                        output_tokens[idx] += len(output.get("token_ids", [])) or 0
+                        num_cache_tokens[idx] += output.get("num_cache_tokens") or 0
+                        if output.get("num_image_tokens"):
+                            output_tokens[idx] += output.get("num_image_tokens")
+                            num_image_tokens[idx] += output.get("num_image_tokens")
                         reasoning_tokens[idx] += output.get("reasoning_token_num", 0)
                         delta_message = CompletionResponseStreamChoice(
                             index=idx,
@@ -527,8 +534,9 @@ async def completion_stream_generator(
                                     prompt_batched_token_ids[idx // (1 if request.n is None else request.n)]
                                 )
                                 + output_tokens[idx],
+                                prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cache_tokens[idx]),
                                 completion_tokens_details=CompletionTokenUsageInfo(
-                                    reasoning_tokens=reasoning_tokens[idx]
+                                    image_tokens=num_image_tokens[idx], reasoning_tokens=reasoning_tokens[idx]
                                 ),
                             ),
                         )
@@ -559,6 +567,8 @@ def request_output_to_completion_response(
         choices: List[CompletionResponseChoice] = []
         num_prompt_tokens = 0
         num_generated_tokens = 0
+        num_cache_tokens = 0
+        num_image_tokens = 0
         num_reasoning_tokens = 0

         for idx in range(len(final_res_batch)):
@@ -614,6 +624,10 @@ def request_output_to_completion_response(

             num_generated_tokens += final_res["output_token_ids"]
             num_prompt_tokens += len(prompt_token_ids)

+            num_cache_tokens += output.get("num_cache_tokens") or 0
+            if output.get("num_image_tokens"):
+                num_generated_tokens += output.get("num_image_tokens")
+                num_image_tokens += output.get("num_image_tokens")
             num_reasoning_tokens += output.get("reasoning_token_num", 0)

@@ -622,7 +636,10 @@
             prompt_tokens=num_prompt_tokens,
             completion_tokens=num_generated_tokens,
             total_tokens=num_prompt_tokens + num_generated_tokens,
-            completion_tokens_details=CompletionTokenUsageInfo(reasoning_tokens=num_reasoning_tokens),
+            prompt_tokens_details=PromptTokenUsageInfo(cached_tokens=num_cache_tokens),
+            completion_tokens_details=CompletionTokenUsageInfo(
+                reasoning_tokens=num_reasoning_tokens, image_tokens=num_image_tokens
+            ),
         )

         del request
diff --git a/fastdeploy/entrypoints/openai/usage_calculator.py b/fastdeploy/entrypoints/openai/usage_calculator.py
new file mode 100644
index 00000000000..4cf22ded0d6
--- /dev/null
+++ b/fastdeploy/entrypoints/openai/usage_calculator.py
@@ -0,0 +1,33 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import numpy as np
+
+
+def count_tokens(tokens):
+    """
+    Count the number of tokens in a nested list or array structure.
+    """
+    count = 0
+    stack = [tokens]
+    while stack:
+        current = stack.pop()
+        if isinstance(current, (list, tuple, np.ndarray)):
+            for item in reversed(current):
+                stack.append(item)
+        else:
+            count += 1
+    return count
diff --git a/tests/ce/server/test_logprobs.py b/tests/ce/server/test_logprobs.py
index bba2e0f5bd2..2431fe2e3ae 100644
--- a/tests/ce/server/test_logprobs.py
+++ b/tests/ce/server/test_logprobs.py
@@ -32,13 +32,9 @@ def test_unstream_with_logprobs():
         "bytes": [231, 137, 155, 233, 161, 191],
         "top_logprobs": None,
     }
-    assert resp_json["usage"] == {
-        "prompt_tokens": 22,
-        "total_tokens": 25,
-        "completion_tokens": 3,
-        "prompt_tokens_details": {"cached_tokens": 0},
-        "completion_tokens_details": {"reasoning_tokens": 0},
-    }
+    assert resp_json["usage"]["prompt_tokens"] == 22
+    assert resp_json["usage"]["completion_tokens"] == 3
+    assert resp_json["usage"]["total_tokens"] == 25


 def test_unstream_without_logprobs():
@@ -65,13 +61,9 @@ def test_unstream_without_logprobs():
     # 校验返回内容与 logprobs 字段
     assert resp_json["choices"][0]["message"]["content"] == "牛顿的"
     assert resp_json["choices"][0]["logprobs"] is None
-    assert resp_json["usage"] == {
-        "prompt_tokens": 22,
-        "total_tokens": 25,
-        "completion_tokens": 3,
-        "prompt_tokens_details": {"cached_tokens": 0},
-        "completion_tokens_details": {"reasoning_tokens": 0},
-    }
+    assert resp_json["usage"]["prompt_tokens"] == 22
+    assert resp_json["usage"]["completion_tokens"] == 3
+    assert resp_json["usage"]["total_tokens"] == 25


 def test_stream_with_logprobs():
diff --git a/tests/entrypoints/openai/test_max_streaming_tokens.py b/tests/entrypoints/openai/test_max_streaming_tokens.py
index fe551506898..3857f603e4b 100644
--- a/tests/entrypoints/openai/test_max_streaming_tokens.py
+++ b/tests/entrypoints/openai/test_max_streaming_tokens.py
@@ -388,6 +388,7 @@ async def test_create_chat_completion_choice(self):
                     "reasoning_content": "Normal reasoning",
                     "tool_call": None,
                     "num_cached_tokens": 3,
+                    "num_image_tokens": 2,
                     "raw_prediction": "raw_answer_0",
                 },
                 "finished": True,
@@ -403,6 +404,7 @@ async def test_create_chat_completion_choice(self):
                     "tool_calls": None,
                     "raw_prediction": "raw_answer_0",
                     "num_cached_tokens": 3,
+                    "num_image_tokens": 2,
                     "finish_reason": "stop",
                 },
             },
@@ -415,6 +417,7 @@ async def test_create_chat_completion_choice(self):
                     "reasoning_content": None,
                     "tool_call": None,
                     "num_cached_tokens": 0,
+                    "num_image_tokens": 0,
                     "raw_prediction": None,
                 },
                 "finished": True,
@@ -430,6 +433,7 @@ async def test_create_chat_completion_choice(self):
                     "tool_calls": None,
                     "raw_prediction": None,
                     "num_cached_tokens": 0,
+                    "num_image_tokens": 0,
                     "finish_reason": "stop",
                 },
             },
@@ -442,6 +446,7 @@ async def test_create_chat_completion_choice(self):
         mock_response_processor.enable_multimodal_content.return_value = False
         completion_token_ids = [[], []]
         num_cached_tokens = [0, 0]
+        num_image_tokens = [0, 0]

         for idx, case in enumerate(test_cases):
             actual_choice = await self.chat_serving._create_chat_completion_choice(
@@ -453,6 +458,7 @@ async def test_create_chat_completion_choice(self):
                 prompt_tokens=prompt_tokens,
                 completion_token_ids=completion_token_ids[idx],
                 num_cached_tokens=num_cached_tokens,
+                num_image_tokens=num_image_tokens,
                 logprob_contents=logprob_contents,
                 response_processor=mock_response_processor,
             )
@@ -468,6 +474,7 @@ async def test_create_chat_completion_choice(self):
             self.assertEqual(actual_choice.message.completion_token_ids, completion_token_ids[idx])

             self.assertEqual(num_cached_tokens[expected["index"]], expected["num_cached_tokens"])
+            self.assertEqual(num_image_tokens[expected["index"]], expected["num_image_tokens"])
             self.assertEqual(actual_choice.finish_reason, expected["finish_reason"])

             assert actual_choice.logprobs is not None
diff --git a/tests/entrypoints/openai/test_usage_calculator.py b/tests/entrypoints/openai/test_usage_calculator.py
new file mode 100644
index 00000000000..3a22b4f372e
--- /dev/null
+++ b/tests/entrypoints/openai/test_usage_calculator.py
@@ -0,0 +1,158 @@
+"""
+# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+
+import numpy as np
+
+from fastdeploy.entrypoints.openai.usage_calculator import count_tokens
+
+
+class TestCountTokens:
+    """Test cases for count_tokens function"""
+
+    def test_empty_list(self):
+        """Test counting tokens in an empty list"""
+        tokens = []
+        result = count_tokens(tokens)
+        assert result == 0
+
+    def test_flat_list_of_integers(self):
+        """Test counting tokens in a flat list of integers"""
+        tokens = [1, 2, 3, 4, 5]
+        result = count_tokens(tokens)
+        assert result == 5
+
+    def test_flat_list_of_strings(self):
+        """Test counting tokens in a flat list of strings"""
+        tokens = ["hello", "world", "test"]
+        result = count_tokens(tokens)
+        assert result == 3
+
+    def test_flat_numpy_array(self):
+        """Test counting tokens in a flat numpy array"""
+        tokens = np.array([1, 2, 3, 4, 5])
+        result = count_tokens(tokens)
+        assert result == 5
+
+    def test_nested_list_one_level(self):
+        """Test counting tokens in a nested list with one level of nesting"""
+        tokens = [[1, 2], [3, 4], [5]]
+        result = count_tokens(tokens)
+        assert result == 5
+
+    def test_nested_list_multiple_levels(self):
+        """Test counting tokens in a deeply nested list"""
+        tokens = [[1, [2, 3]], [4, [5, [6]]], 7]
+        result = count_tokens(tokens)
+        assert result == 7
+
+    def test_nested_tuple(self):
+        """Test counting tokens in nested tuples"""
+        tokens = ((1, 2), (3, (4, 5)), 6)
+        result = count_tokens(tokens)
+        assert result == 6
+
+    def test_mixed_nested_structures(self):
+        """Test counting tokens in mixed nested structures (list, tuple, numpy array)"""
+        tokens = [1, (2, 3), np.array([4, 5]), [6, [7, 8]]]
+        result = count_tokens(tokens)
+        assert result == 8
+
+    def test_single_element_list(self):
+        """Test counting tokens in a list with single element"""
+        tokens = [42]
+        result = count_tokens(tokens)
+        assert result == 1
+
+    def test_single_element_tuple(self):
+        """Test counting tokens in a tuple with single element"""
+        tokens = (42,)
+        result = count_tokens(tokens)
+        assert result == 1
+
+    def test_single_element_numpy_array(self):
+        """Test counting tokens in a numpy array with single element"""
+        tokens = np.array([42])
+        result = count_tokens(tokens)
+        assert result == 1
+
+    def test_nested_empty_lists(self):
+        """Test counting tokens in nested empty lists"""
+        tokens = [[], [[]], [[[]]]]
+        result = count_tokens(tokens)
+        assert result == 0
+
+    def test_complex_mixed_structure(self):
+        """Test counting tokens in a complex mixed structure"""
+        tokens = [
+            1,
+            [2, 3, (4, np.array([5, 6]))],
+            [7, [8, 9, [10]]],
+            (11, [12, 13]),
+            np.array([14, 15]),  # Note: numpy arrays can't contain lists directly
+        ]
+        # Flatten the structure manually for expected count
+        result = count_tokens(tokens)
+        assert result == 15
+
+    def test_large_flat_list(self):
+        """Test counting tokens in a large flat list"""
+        tokens = list(range(1000))
+        result = count_tokens(tokens)
+        assert result == 1000
+
+    def test_none_values(self):
+        """Test counting tokens when list contains None values"""
+        tokens = [1, None, 2, [None, 3], None]
+        result = count_tokens(tokens)
+        assert result == 6
+
+    def test_boolean_values(self):
+        """Test counting tokens with boolean values"""
+        tokens = [True, False, [True, False]]
+        result = count_tokens(tokens)
+        assert result == 4
+
+    def test_float_values(self):
+        """Test counting tokens with float values"""
+        tokens = [1.5, 2.7, [3.14, 4.2]]
+        result = count_tokens(tokens)
+        assert result == 4
+
+    def test_mixed_data_types(self):
+        """Test counting tokens with mixed data types"""
+        tokens = [1, "hello", 2.5, True, None, [1, "world"]]
+        result = count_tokens(tokens)
+        assert result == 7
+
+    def test_deeply_nested_structure(self):
+        """Test counting tokens in a very deeply nested structure"""
+        tokens = 1
+        for _ in range(100):
+            tokens = [tokens]
+        result = count_tokens(tokens)
+        assert result == 1
+
+    def test_numpy_array_2d(self):
+        """Test counting tokens in a 2D numpy array"""
+        tokens = np.array([[1, 2], [3, 4], [5, 6]])
+        result = count_tokens(tokens)
+        assert result == 6
+
+    def test_numpy_array_3d(self):
+        """Test counting tokens in a 3D numpy array"""
+        tokens = np.array([[[1, 2], [3, 4]], [[5, 6], [7, 8]]])
+        result = count_tokens(tokens)
+        assert result == 8