diff --git a/3rdparty/tvm b/3rdparty/tvm index 9c894f78fd..7752c9221c 160000 --- a/3rdparty/tvm +++ b/3rdparty/tvm @@ -1 +1 @@ -Subproject commit 9c894f78fdef156263ced19eed67e79203ca4a11 +Subproject commit 7752c9221c768617af01711f8ad155e0a1cd409e diff --git a/3rdparty/xgrammar b/3rdparty/xgrammar index d4f57c440f..dbf200ecde 160000 --- a/3rdparty/xgrammar +++ b/3rdparty/xgrammar @@ -1 +1 @@ -Subproject commit d4f57c440f3da8e7330a1e5d50bba9c31f9433ea +Subproject commit dbf200ecde5dd5467c8320076ee60b1e248b23e0 diff --git a/CMakeLists.txt b/CMakeLists.txt index a010a05192..99926be832 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -71,7 +71,7 @@ add_subdirectory(${TOKENZIER_CPP_PATH} tokenizers EXCLUDE_FROM_ALL) set(XGRAMMAR_PATH 3rdparty/xgrammar) tvm_file_glob(GLOB_RECURSE MLC_LLM_SRCS cpp/*.cc) tvm_file_glob(GLOB_RECURSE XGRAMMAR_SRCS ${XGRAMMAR_PATH}/cpp/*.cc) -list(FILTER XGRAMMAR_SRCS EXCLUDE REGEX "${XGRAMMAR_PATH}/cpp/pybind/.*\\.cc") +list(FILTER XGRAMMAR_SRCS EXCLUDE REGEX "${XGRAMMAR_PATH}/cpp/nanobind/.*\\.cc") list(APPEND MLC_LLM_SRCS ${XGRAMMAR_SRCS}) add_library(mlc_llm_objs OBJECT ${MLC_LLM_SRCS}) diff --git a/cpp/serve/config.cc b/cpp/serve/config.cc index f7e71e72c9..3be0b815bd 100644 --- a/cpp/serve/config.cc +++ b/cpp/serve/config.cc @@ -42,13 +42,43 @@ Result<ResponseFormat> ResponseFormat::FromJSON(const picojson::object& config) ResponseFormat res; res.type = json::LookupOrDefault<std::string>(config, "type", "text"); + if (res.type != "text" && res.type != "function" && res.type != "json_object" && + res.type != "json_schema" && res.type != "structural_tag") { + return TResult::Error("Unknown response_format type " + res.type); + } + std::optional<std::string> schema = json::LookupOptional<std::string>(config, "schema"); if (schema.has_value()) { res.schema = schema.value(); } - if (res.type != "text" && res.type != "function" && res.type != "json_object") { - return TResult::Error("Uknonwn response_format type " + res.type); + if (auto tags_obj = json::LookupOptional<picojson::array>(config, "tags")) { + auto tags = Array<Array<String>>(); + for (auto tag_obj : tags_obj.value()) { + Array<String> tag = Array<String>(); + std::optional<std::string> begin = + json::LookupOptional<std::string>(tag_obj.get<picojson::object>(), "begin"); + std::optional<std::string> schema = + json::LookupOptional<std::string>(tag_obj.get<picojson::object>(), "schema"); + std::optional<std::string> end = + json::LookupOptional<std::string>(tag_obj.get<picojson::object>(), "end"); + if (!(begin.has_value() && schema.has_value() && end.has_value())) { + return TResult::Error("Missing tag attribute."); + } + tag.push_back(begin.value()); + tag.push_back(schema.value()); + tag.push_back(end.value()); + tags.push_back(std::move(tag)); + } + res.tags = tags; + } + + if (auto triggers_obj = json::LookupOptional<picojson::array>(config, "triggers")) { + auto triggers = Array<String>(); + for (auto trigger : triggers_obj.value()) { + triggers.push_back(trigger.get<std::string>()); + } + res.triggers = triggers; } return TResult::Ok(res); @@ -60,6 +90,24 @@ picojson::object ResponseFormat::AsJSON() const { if (schema.defined()) { config["schema"] = picojson::value(schema.value().operator std::string()); } + if (tags.defined()) { + picojson::array tags_obj = picojson::array(); + for (auto tag : tags.value()) { + picojson::array tag_obj = picojson::array(); + tag_obj.emplace_back(tag[0]); + tag_obj.emplace_back(tag[1]); + tag_obj.emplace_back(tag[2]); + tags_obj.emplace_back(tag_obj); + } + config["tags"] = picojson::value(tags_obj); + } + if (triggers.defined()) { + picojson::array trigger_obj = picojson::array(); + for (std::string trigger : triggers.value()) { + trigger_obj.emplace_back(trigger); + } + config["triggers"] = picojson::value(trigger_obj); + } return 
config; } @@ -1073,4 +1121,4 @@ Result<bool> ModelsUseKVCache(const std::vector<picojson::object>& model_configs } // namespace serve } // namespace llm -} // namespace mlc +} // namespace mlc \ No newline at end of file diff --git a/cpp/serve/config.h b/cpp/serve/config.h index 9da3ba2517..c73b7d0b50 100644 --- a/cpp/serve/config.h +++ b/cpp/serve/config.h @@ -28,6 +28,8 @@ using namespace tvm::runtime; struct ResponseFormat { String type = "text"; Optional<String> schema = NullOpt; + Optional<Array<Array<String>>> tags = NullOpt; + Optional<Array<String>> triggers = NullOpt; /*! * \brief Create debug config from JSON. * \param config_json The json string for generation config diff --git a/cpp/serve/engine.cc b/cpp/serve/engine.cc index 2f09219392..c2b967ecf3 100644 --- a/cpp/serve/engine.cc +++ b/cpp/serve/engine.cc @@ -35,6 +35,7 @@ #include "request.h" #include "request_state.h" #include "sampler/sampler.h" +#include "xgrammar/tokenizer_info.h" namespace mlc { namespace llm { @@ -64,6 +65,9 @@ inline std::optional<TokenizerInfo> GetTokenizerInfo(const picojson::object& mod if (tokenizer_info_obj.count("strip_space_in_decode")) { info->strip_space_in_decode = tokenizer_info_obj.at("strip_space_in_decode").get<bool>(); } + if (model_config.count("vocab_size")) { + info->vocab_size = model_config.at("vocab_size").get<int64_t>(); + } return TokenizerInfo(info); } @@ -463,9 +467,17 @@ class EngineImpl : public Engine { ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()}); } // - Initialize tokenizer and grammar - n->tokenizer_ = Tokenizer::FromPath(engine_config->model, GetTokenizerInfo(model_configs[0])); + + std::optional<TokenizerInfo> info = GetTokenizerInfo(model_configs[0]); + n->tokenizer_ = Tokenizer::FromPath(engine_config->model, info); n->token_table_ = n->tokenizer_->PostProcessedTokenTable(); - n->cached_grammar_compiler_ = xgrammar::CachedGrammarCompiler(n->token_table_); + int64_t vocab_size = n->tokenizer_->GetVocabSize(); + if (info.has_value() && info.value()->vocab_size != 0) { + vocab_size = info.value()->vocab_size; + } + n->grammar_compiler_ = xgrammar::GrammarCompiler( + xgrammar::TokenizerInfo(n->token_table_, xgrammar::VocabType::RAW, vocab_size)); + // - Create the logit processor and sampler, and // the DraftTokenWorkspaceManager for speculative decoding. int max_num_tokens = engine_config->max_num_sequence; @@ -975,13 +987,22 @@ class EngineImpl : public Engine { * is not JSON, return std::nullopt. 
*/ std::optional<xgrammar::CompiledGrammar> GetGrammarFromResponseFormat( const ResponseFormat& response_format) { - if (response_format.type != "json_object") { + if (response_format.type == "text") { return std::nullopt; - } else if (!response_format.schema) { - return cached_grammar_compiler_.GetCompiledGrammarForJSON(); - } else { - return cached_grammar_compiler_.GetCompiledGrammarForJSONSchema( - response_format.schema.value()); + } else if (response_format.type == "json_object") { + return grammar_compiler_.CompileBuiltinJSONGrammar(); + } else if (response_format.type == "json_schema") { + return grammar_compiler_.CompileJSONSchema(response_format.schema.value()); + } else if (response_format.type == "structural_tag") { + std::vector<xgrammar::StructuralTagItem> tags; + std::vector<std::string> triggers; + for (auto tag : response_format.tags.value()) { + tags.emplace_back(xgrammar::StructuralTagItem{tag[0], tag[1], tag[2]}); + } + for (auto trigger : response_format.triggers.value()) { + triggers.emplace_back(trigger); + } + return grammar_compiler_.CompileStructuralTag(std::move(tags), std::move(triggers)); } } @@ -992,8 +1013,8 @@ class EngineImpl : public Engine { // internal tokenizer Tokenizer tokenizer_; std::vector<std::string> token_table_; - // Cached grammar compiler for grammar matching. - xgrammar::CachedGrammarCompiler cached_grammar_compiler_; + // Grammar compiler for grammar matching. + xgrammar::GrammarCompiler grammar_compiler_; // Models Array<Model> models_; // Device that the models run on. diff --git a/cpp/serve/request_state.cc b/cpp/serve/request_state.cc index 17e02ee85b..4771f14d3b 100644 --- a/cpp/serve/request_state.cc +++ b/cpp/serve/request_state.cc @@ -24,7 +24,7 @@ RequestModelState::RequestModelState( if (compiled_grammar.has_value()) { // TODO(yixin): set rollback limit to a configurable value. 
n->grammar_matcher = - xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, std::nullopt, 10); + xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, 10); } n->request = std::move(request); @@ -44,7 +44,7 @@ bool RequestModelStateNode::RequireNextTokenBitmask() { return grammar_matcher.h void RequestModelStateNode::GetNextTokenBitmask(DLTensor* bitmask) { ICHECK(grammar_matcher.has_value()); - grammar_matcher->GetNextTokenBitmask(bitmask); + grammar_matcher->FillNextTokenBitmask(bitmask); } void RequestModelStateNode::CommitToken(SampleResult sampled_token) { diff --git a/cpp/tokenizers/tokenizers.cc b/cpp/tokenizers/tokenizers.cc index 13ae547d72..2c0d1f4e96 100644 --- a/cpp/tokenizers/tokenizers.cc +++ b/cpp/tokenizers/tokenizers.cc @@ -30,6 +30,7 @@ String TokenizerInfoNode::AsJSONString() const { obj["token_postproc_method"] = picojson::value(token_postproc_method); obj["prepend_space_in_encode"] = picojson::value(prepend_space_in_encode); obj["strip_space_in_decode"] = picojson::value(strip_space_in_decode); + obj["vocab_size"] = picojson::value(vocab_size); return picojson::value(obj).serialize(false); } @@ -54,6 +55,10 @@ TokenizerInfo TokenizerInfo::FromJSONString(String json_string) { ICHECK(obj.at("strip_space_in_decode").is<bool>()); n->strip_space_in_decode = obj.at("strip_space_in_decode").get<bool>(); } + if (obj.count("vocab_size")) { + ICHECK(obj.at("vocab_size").is<int64_t>()); + n->vocab_size = obj.at("vocab_size").get<int64_t>(); + } return TokenizerInfo(n); } diff --git a/cpp/tokenizers/tokenizers.h b/cpp/tokenizers/tokenizers.h index 2b1847f524..bca2a0e50c 100644 --- a/cpp/tokenizers/tokenizers.h +++ b/cpp/tokenizers/tokenizers.h @@ -43,6 +43,9 @@ class TokenizerInfoNode : public Object { bool prepend_space_in_encode = false; /*! \brief Whether to strip the first space during decoding. */ bool strip_space_in_decode = false; + /*! \brief The vocab_size in config.json (length of logits). This may be larger than the vocabulary + * size. The value will be 0 if not set. */ + int64_t vocab_size = 0; String AsJSONString() const; diff --git a/python/mlc_llm/protocol/conversation_protocol.py b/python/mlc_llm/protocol/conversation_protocol.py index 71738efeef..14adcb8dd1 100644 --- a/python/mlc_llm/protocol/conversation_protocol.py +++ b/python/mlc_llm/protocol/conversation_protocol.py @@ -81,6 +81,8 @@ class Conversation(BaseModel): function_string: str = "" # whether using function calling or not, helps check for output message format in API call use_function_calling: bool = False + # Tool function call format mode + _tool_call_format: str = "json" def __init__(self, role_templates: Optional[Dict[str, str]] = None, **kwargs): # Defaults templates which would be overridden by model specific templates @@ -124,6 +126,8 @@ def as_prompt(self, config=None) -> List[Any]: from ..serve import data # pylint: disable=import-outside-toplevel # - Get the system message. 
+ if self.use_function_calling: + self.set_tool_call_format_in_system_message() system_msg = self.system_template.replace( MessagePlaceholders.SYSTEM.value, self.system_message ) @@ -195,6 +199,51 @@ def as_prompt(self, config=None) -> List[Any]: return prompt + def set_tool_call_format_in_system_message(self): + """Add tool function information and call format to the system message.""" + if self._tool_call_format == "json": + tool_call_instruct = ( + "Tool Instructions:" + "You have access to the following tool functions:" + f"{MessagePlaceholders.FUNCTION.value}" + "If you choose to call a function, you should ONLY reply in the following format:" + '`{"name": func_name, "parameters": parameters(JSON dict)}`' + "Here is an example," + '`{"name": "get_time", "parameters": {"location": "Pittsburgh"}}`' + "Reminder:" + "- Function calls MUST follow the specified format" + "- Required parameters MUST be specified" + "- You should not repeat or omit the call" + "- You should respond with at least one function call" + ) + self.system_message += tool_call_instruct + elif self._tool_call_format == "xml": + tool_call_instruct = ( + "Tool Instructions:" + "You have access to the following tool functions:" + f"{MessagePlaceholders.FUNCTION.value}" + "If you choose to call a function, you should ONLY reply in the following format:" + "`<function=func_name>{parameters(JSON dict)}</function>`" + "Here is an example," + '`<function=get_time>{"location": "Pittsburgh"}</function>`' + "Reminder:" + "- Function calls MUST follow the specified format" + "- Required parameters MUST be specified" + "- You should not repeat or omit the call" + ) + self.system_message += tool_call_instruct + elif self._tool_call_format == "python": + tool_call_instruct = ( + "Tool Instructions:" + "- You have access to the following tool functions:" + f"{MessagePlaceholders.FUNCTION.value}" + "- Required parameters MUST be specified" + "- You should not repeat or omit the call" + ) + self.system_message += tool_call_instruct + else: + raise ValueError("Unknown tool calling format.") + def _get_url_from_item(item: Dict) -> str: image_url: str diff --git a/python/mlc_llm/protocol/openai_api_protocol.py b/python/mlc_llm/protocol/openai_api_protocol.py index cb2e1f2852..d11319acd2 100644 --- a/python/mlc_llm/protocol/openai_api_protocol.py +++ b/python/mlc_llm/protocol/openai_api_protocol.py @@ -86,12 +86,46 @@ class ModelResponse(BaseModel): class RequestResponseFormat(BaseModel): - type: Literal["text", "json_object"] = "text" - json_schema: Optional[str] = Field(default=None, alias="schema") + type: Literal["text", "json_object", "json_schema", "structural_tag"] = "text" """This field is named json_schema instead of schema because BaseModel defines a method called schema. 
During construction of RequestResponseFormat, key "schema" still should be used: - `RequestResponseFormat(type="json_object", schema="{}")` + `RequestResponseFormat(type="json_schema", schema="{}")` """ + json_schema: Optional[str] = Field(default=None, alias="schema") + + """These fields are only used for type="structural_tag".""" + tags: Optional[List[Dict[str, str]]] = None + triggers: Optional[List[str]] = None + + @model_validator(mode="after") + def check_request_response_format(self) -> "RequestResponseFormat": + """Check if the RequestResponseFormat is valid.""" + if self.type in ["text", "json_object"]: + if self.json_schema is not None: + raise Warning("'json_schema' should only be set when type is 'json_schema'.") + if self.tags is not None or self.triggers is not None: + raise Warning( + "'tags' and 'triggers' attributes should only be used when type='structural_tag'" + ) + elif self.type == "json_schema": + if self.json_schema is None: + raise ValueError("'json_schema' should be set in 'json_schema' type.") + if self.tags is not None or self.triggers is not None: + raise Warning( + "'tags' and 'triggers' attributes should only be used when type='structural_tag'" + ) + elif self.type == "structural_tag": + if self.tags is None or self.triggers is None: + raise ValueError("structural_tag type must contain keys 'tags' and 'triggers'.") + for tag in self.tags: # pylint: disable=not-an-iterable + if set(tag.keys()) != {"begin", "schema", "end"}: + raise ValueError( + "Each tag must contain exactly 'begin', 'schema' and 'end' keys. " + f"Got keys: {list(tag.keys())}." + ) + if self.json_schema is not None: + raise Warning("'json_schema' should only be set when type is 'json_schema'.") + return self class CompletionRequest(BaseModel): @@ -181,6 +215,7 @@ class ChatFunction(BaseModel): description: Optional[str] = None name: str parameters: Dict + strict: bool = True class ChatTool(BaseModel): @@ -318,12 +353,10 @@ def check_function_call_usage(self, conv_template: Conversation) -> None: """Check if function calling is used and update the conversation template. Return error message if invalid request format for function calling. """ - # return if no tools are provided or tool_choice is set to none if self.tools is None or (isinstance(self.tool_choice, str) and self.tool_choice == "none"): conv_template.use_function_calling = False return - # select the tool based on the tool_choice if specified if isinstance(self.tool_choice, dict): if self.tool_choice["type"] != "function": # pylint: disable=unsubscriptable-object diff --git a/python/mlc_llm/serve/config.py b/python/mlc_llm/serve/config.py index 9b82de8350..ff8a7e69d7 100644 --- a/python/mlc_llm/serve/config.py +++ b/python/mlc_llm/serve/config.py @@ -130,6 +130,17 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes "hybrid" means the hybrid prefill or split-fuse, so that decode step will be converted into prefill. + tool_call_format : Literal["json", "xml", "python"] + The tool function call format. + "json" means the model will call tool functions in JSON-style format + '{"name": func_name, "parameters": parameters(JSON dict)}', + e.g. '{"name": "get_time", "parameters": {"location": "Pittsburgh"}}'. + "xml" means the model will call tool functions in XML-style format + '<function=func_name>{parameters(JSON dict)}</function>', + e.g. '<function=get_time>{"location": "Pittsburgh"}</function>'. + "python" means the model will call tool functions in Python-style format, + e.g. 'wolfram_alpha.call(query="solve x^3 - 4x^2 + 6x - 24 = 0")'. + verbose : bool A boolean indicating whether to print logging info in engine. 
""" @@ -157,6 +168,7 @@ class EngineConfig: # pylint: disable=too-many-instance-attributes prefix_cache_mode: Literal["disable", "radix"] = "radix" prefix_cache_max_num_recycling_seqs: Optional[int] = None prefill_mode: Literal["chunked", "hybrid"] = "hybrid" + tool_call_format: Literal["json", "xml", "python"] = "json" verbose: bool = True def asjson(self) -> str: diff --git a/python/mlc_llm/serve/engine.py b/python/mlc_llm/serve/engine.py index 3d9d181b1f..57841dd48f 100644 --- a/python/mlc_llm/serve/engine.py +++ b/python/mlc_llm/serve/engine.py @@ -1056,7 +1056,7 @@ async def _chat_completion( # pylint: disable=too-many-arguments,too-many-local assert all(finish_reason is not None for finish_reason in finish_reasons) use_function_calling, tool_calls_list = engine_base.process_function_call_output( - output_texts, finish_reasons + output_texts, finish_reasons, self.engine_config.tool_call_format ) return engine_base.wrap_chat_completion_response( request_id=request_id, @@ -1207,6 +1207,12 @@ async def _handle_chat_completion( e : BadRequestError BadRequestError is raised when the request is invalid. """ + request.response_format = engine_base.set_structural_tag_from_tools( + request.tools, + request.response_format, + request.tool_choice, + self.engine_config.tool_call_format, + ) ( prompts, generation_cfg, @@ -1617,7 +1623,7 @@ def _chat_completion( # pylint: disable=too-many-arguments,too-many-locals assert all(finish_reason is not None for finish_reason in finish_reasons) use_function_calling, tool_calls_list = engine_base.process_function_call_output( - output_texts, finish_reasons + output_texts, finish_reasons, self.engine_config.tool_call_format ) return engine_base.wrap_chat_completion_response( request_id=request_id, @@ -1764,6 +1770,12 @@ def _handle_chat_completion( e : BadRequestError BadRequestError is raised when the request is invalid. 
""" + request.response_format = engine_base.set_structural_tag_from_tools( + request.tools, + request.response_format, + request.tool_choice, + self.engine_config.tool_call_format, + ) ( prompts, generation_cfg, diff --git a/python/mlc_llm/serve/engine_base.py b/python/mlc_llm/serve/engine_base.py index 1d5303e412..61118f8283 100644 --- a/python/mlc_llm/serve/engine_base.py +++ b/python/mlc_llm/serve/engine_base.py @@ -130,7 +130,9 @@ def _convert_model_info(model: ModelInfo) -> Tuple[str, str]: if conversation is None: conversation = mlc_chat_config.conv_template - + conversation._tool_call_format = ( # pylint: disable=protected-access + engine_config.tool_call_format + ) if model.model_lib is not None: # do model lib search if the model lib is provided # error out if file not found @@ -644,6 +646,7 @@ def __init__( # pylint: disable=too-many-arguments,too-many-locals engine_config.mode = mode self._ffi["reload"](engine_config.asjson()) self.engine_config = EngineConfig.from_json(self._ffi["get_complete_engine_config"]()) + self.engine_config.tool_call_format = engine_config.tool_call_format self.max_input_sequence_length = min( self.engine_config.max_single_sequence_length, self.engine_config.max_total_sequence_length, @@ -1146,36 +1149,172 @@ def create_completion_suffix_response( return response -def convert_function_str_to_json(stringified_calls: str) -> List[Union[Dict, None]]: +def set_structural_tag_from_tools( # pylint: disable=too-many-branches,too-many-boolean-expressions + tools: Optional[List[openai_api_protocol.ChatTool]], + response_format: Optional[openai_api_protocol.RequestResponseFormat], + tool_choice: Optional[Union[Literal["none", "auto"], Dict]], + tool_call_format: str, +): + """Add the corresponding structural tag to the response format according to + the tools to ensure valid function calling. Only set in strict mode of the tool. + Return the updated response format. 
+ """ + if tools is None or (isinstance(tool_choice, str) and tool_choice == "none"): + return response_format + + if response_format is None or response_format.type == "text": + response_format = openai_api_protocol.RequestResponseFormat.model_validate( + {"type": "structural_tag", "tags": [], "triggers": []} + ) + elif response_format.type == "json_object": + response_format.tags = [] + response_format.triggers = [] + + if tool_call_format == "json": + begin_format = '{{"name": "{func_name}", "parameters":' + end = "}" + for tool in tools: + if tool.function.strict and ( + tool_choice is None + or (isinstance(tool_choice, str) and tool_choice == "auto") + or ( + isinstance(tool_choice, dict) + and tool.function.name == tool_choice["function"]["name"] + ) + ): + schema = { + "properties": tool.function.parameters["properties"], + "required": tool.function.parameters["required"], + "type": tool.function.parameters["type"], + } + response_format.tags.append( + { + "begin": begin_format.format(func_name=tool.function.name), + "schema": json.dumps(schema), + "end": end, + } + ) + response_format.triggers.append('{"name":') + + elif tool_call_format == "xml": + begin_format = "" + end = "" + for tool in tools: + if tool.function.strict and ( + tool_choice is None + or (isinstance(tool_choice, str) and tool_choice == "auto") + or ( + isinstance(tool_choice, dict) + and tool.function.name == tool_choice["function"]["name"] + ) + ): + schema = { + "properties": tool.function.parameters["properties"], + "required": tool.function.parameters["required"], + "type": tool.function.parameters["type"], + } + response_format.tags.append( + { + "begin": begin_format.format(func_name=tool.function.name), + "schema": json.dumps(schema), + "end": end, + } + ) + response_format.triggers.append(" List[Union[Dict, None]]: """Convert a (possibly list) of function call string to a list of json objects. Return None for invalid function call string.""" + function_calls_json = [] + + if tool_call_format == "json": + # tool calling in format `{"name": func_name, "parameters": parameters(JSON dict)}` + start = 0 + while True: + index = stringified_calls.find('{"name":', start) + if index == -1: + break + try: + decoder = json.JSONDecoder() + result, end_index = decoder.raw_decode(stringified_calls, index) + except: # pylint: disable=bare-except + start = index + 1 + continue + start = end_index + if not isinstance(result, dict) or "name" not in result or "parameters" not in result: + continue + function_calls_json.append({"name": result["name"], "arguments": result["parameters"]}) + + elif tool_call_format == "xml": + # tool calling in format `{PARA}` + start = 0 + while True: + begin_start = stringified_calls.find("", begin_start) + if begin_end == -1: + break + end_start = stringified_calls.find("", begin_end) + if end_start == -1: + break + start = end_start + len("") + + func_name = stringified_calls[begin_start + len(" Tuple[bool, List[List[openai_api_protocol.ChatToolCall]]]: """Process the potential function call results outputted by model, - according to the finish reasons. + according to the finish reasons and the tool calling format. Return whether the output has function call, and the list of tool calls. 
""" n = len(output_texts) @@ -1184,9 +1323,9 @@ def process_function_call_output( if use_function_calling: for i, output_text in enumerate(output_texts): try: - fn_json_list = convert_function_str_to_json(output_text) + fn_json_list = convert_function_str_to_json(output_text, tool_call_format) except (SyntaxError, ValueError): - output_text = "Got an invalid function call output from model" + output_text += "[engine info] Got an invalid function call output from model" finish_reasons[i] = "error" else: tool_calls_list[i] = [ @@ -1200,7 +1339,9 @@ def process_function_call_output( if fn_json_obj is not None ] if len(tool_calls_list[i]) == 0: - output_texts[i] = "Got an invalid function call output from model" + output_texts[ + i + ] += "[engine info] Got an invalid function call output from model" finish_reasons[i] = "error" else: finish_reasons[i] = "tool_calls" @@ -1228,7 +1369,7 @@ def wrap_chat_completion_response( # pylint: disable=too-many-arguments openai_api_protocol.ChatCompletionMessage(role="assistant", content=output_text) if not use_function_calling or finish_reason == "error" else openai_api_protocol.ChatCompletionMessage( - role="assistant", tool_calls=tool_calls + role="assistant", tool_calls=tool_calls, content=output_text ) ), logprobs=( diff --git a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py index 18a415e413..5c5637e021 100644 --- a/python/mlc_llm/serve/entrypoints/openai_entrypoints.py +++ b/python/mlc_llm/serve/entrypoints/openai_entrypoints.py @@ -219,10 +219,9 @@ async def completion_stream_generator() -> AsyncGenerator[str, None]: if choice.logprobs is not None: assert logprob_results is not None logprob_results[choice.index] += choice.logprobs.content - assert all(finish_reason is not None for finish_reason in finish_reasons) use_function_calling, tool_calls_list = engine_base.process_function_call_output( - output_texts, finish_reasons + output_texts, finish_reasons, async_engine.engine_config.tool_call_format ) return engine_base.wrap_chat_completion_response( diff --git a/tests/python/serve/server/test_server_structural_tag.py b/tests/python/serve/server/test_server_structural_tag.py new file mode 100644 index 0000000000..f34df96d31 --- /dev/null +++ b/tests/python/serve/server/test_server_structural_tag.py @@ -0,0 +1,432 @@ +# pylint: disable=line-too-long +""" +Test script for structural tag in chat completion. To run this script, use the following command: +- start a new shell session, run + mlc_llm serve --model "YOUR_MODEL" (e.g. 
./dist/Llama-2-7b-chat-hf-q0f16-MLC) +- start another shell session, run this file + MLC_SERVE_MODEL="YOUR_MODEL" python tests/python/serve/server/test_server_structural_tag.py +""" + +# pylint: disable=missing-function-docstring,too-many-arguments,too-many-locals,too-many-branches +import json +import os +import re +from typing import Any, Dict, List, Optional + +import pytest +import requests + +OPENAI_V1_CHAT_COMPLETION_URL = "http://127.0.0.1:8000/v1/chat/completions" + + +def check_openai_nonstream_response( + response: Dict, + *, + model: str, + object_str: str, + num_choices: int, + finish_reason: List[str], + completion_tokens: Optional[int] = None, +): + assert response["model"] == model + assert response["object"] == object_str + + choices = response["choices"] + assert isinstance(choices, list) + assert len(choices) == num_choices + for idx, choice in enumerate(choices): + assert choice["index"] == idx + assert choice["finish_reason"] in finish_reason + + find_format_start = set() + beg_tag_start = set() + message = choice["message"]["content"] + print("Outputs:\n-----------") + print(message, flush=True) + pattern1 = r"(.*?)\|(.*?)\|End<---(.*?)>" + pattern2 = r"(.*?)\|(.*?)\|End<---(.*?)>" + # check format + for match in re.finditer(pattern1, message): + find_format_start.add(match.start()) + check_format(match.group(1), match.group(3), "CALL", match.group(2)) + for match in re.finditer(pattern2, message): + find_format_start.add(match.start()) + check_format(match.group(1), match.group(3), "call", match.group(2)) + for match in re.finditer(r"", message): + beg_tag_start.add(match.start()) + for match in re.finditer(r"", message): + beg_tag_start.add(match.start()) + assert find_format_start == beg_tag_start + + usage = response["usage"] + assert isinstance(usage, dict) + assert usage["total_tokens"] == usage["prompt_tokens"] + usage["completion_tokens"] + assert usage["prompt_tokens"] > 0 + + if completion_tokens is not None: + assert usage["completion_tokens"] == completion_tokens + + +def check_openai_stream_response( + responses: List[Dict], + *, + model: str, + object_str: str, + num_choices: int, + finish_reason: str, + echo_prompt: Optional[str] = None, + suffix: Optional[str] = None, + stop: Optional[List[str]] = None, + require_substr: Optional[List[str]] = None, +): + assert len(responses) > 0 + + finished = [False for _ in range(num_choices)] + outputs = ["" for _ in range(num_choices)] + for response in responses: + assert response["model"] == model + assert response["object"] == object_str + + choices = response["choices"] + assert isinstance(choices, list) + assert len(choices) == num_choices + for idx, choice in enumerate(choices): + assert choice["index"] == idx + + delta = choice["delta"] + assert delta["role"] == "assistant" + assert isinstance(delta["content"], str) + outputs[idx] += delta["content"] + + if finished[idx]: + assert choice["finish_reason"] == finish_reason + elif choice["finish_reason"] is not None: + assert choice["finish_reason"] == finish_reason + finished[idx] = True + + for output in outputs: + if echo_prompt is not None: + assert output.startswith(echo_prompt) + if suffix is not None: + assert output.endswith(suffix) + if stop is not None: + for stop_str in stop: + assert stop_str not in output + if require_substr is not None: + for substr in require_substr: + assert substr in output + find_format_start = set() + beg_tag_start = set() + print("Outputs:\n-----------") + print(output, flush=True) + pattern1 = r"(.*?)\|(.*?)\|End<---(.*?)>" 
+ pattern2 = r"(.*?)\|(.*?)\|End<---(.*?)>" + # check format + for match in re.finditer(pattern1, output): + find_format_start.add(match.start()) + check_format(match.group(1), match.group(3), "CALL", match.group(2)) + for match in re.finditer(pattern2, output): + find_format_start.add(match.start()) + check_format(match.group(1), match.group(3), "call", match.group(2)) + for match in re.finditer(r"", output): + beg_tag_start.add(match.start()) + for match in re.finditer(r"", output): + beg_tag_start.add(match.start()) + assert find_format_start == beg_tag_start + + +def check_format(name_beg: str, name_end: str, beg_tag: str, schema: str): + try: + paras: Dict[str, Any] = json.loads(schema) + except json.JSONDecodeError as e: + print(f"Invalid JSON format: {e}") + assert False + assert "hash_code" in paras + assert "hash_code" in schema + hash_code = paras["hash_code"] + assert hash_code in CHECK_INFO + info = CHECK_INFO[hash_code] + assert name_beg == info["name"] + assert name_end == info["name"] + assert beg_tag == info["beg_tag"] + for key in info["required"]: + assert key in paras + + +# NOTE: the end-tag format and the hash_code number is been hidden in the SYSTEM_PROMPT. +# By checking whether the end tag and hash code can be generated correctly without any prompts, the correctness of the structural tag can be verified. + +SYSTEM_PROMPT = { + "role": "system", + "content": """ +# Tool Instructions +- Always execute python code in messages that you share. +- When looking for real time information use relevant functions if available else fallback to brave_search +You have access to the following functions: +Use the function 'get_current_weather' to: Get the current weather in a given location +{ + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + "hash_code": { + "type": "string", + }, + }, + "required": ["city", "state", "unit", "hash_code"], + }, +} +Use the function 'get_current_date' to: Get the current date and time for a given timezone +{ + "name": "get_current_date", + "description": "Get the current date and time for a given timezone", + "parameters": { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'", + }, + "hash_code": { + "type": "string", + }, + }, + "required": ["timezone", "hash_code"], + }, +} +If a you choose to call a function ONLY reply in the following format: +<{start_tag}--->{function_name}|{parameters}|{end_tag}<---{function_name}> +where +start_tag => ` a JSON dict with the function argument name as key and function argument value as value. +Here is an example, +example_function_name|{"example_name": "example_value"}... +or +example_function_name|{"example_name": "example_value"}... 
+Reminder: +- Function calls MUST follow the specified format +- Required parameters MUST be specified +You are a helpful assistant.""", +} + +STRUCTURAL_TAGS = { + "triggers": ["", ""], + "tags": [ + { + "begin": "get_current_weather|", + "schema": json.dumps( + { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + "hash_code": {"const": 1234}, + }, + "required": ["city", "state", "unit", "hash_code"], + } + ), + "end": "|End<---get_current_weather>", + }, + { + "begin": "get_current_date|", + "schema": json.dumps( + { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'", + }, + "hash_code": {"const": 2345}, + }, + "required": ["timezone", "hash_code"], + } + ), + "end": "|End<---get_current_date>", + }, + { + "begin": "get_current_weather|", + "schema": json.dumps( + { + "type": "object", + "properties": { + "city": { + "type": "string", + "description": "The city to find the weather for, e.g. 'San Francisco'", + }, + "state": { + "type": "string", + "description": "the two-letter abbreviation for the state that the city is" + " in, e.g. 'CA' which would mean 'California'", + }, + "unit": { + "type": "string", + "description": "The unit to fetch the temperature in", + "enum": ["celsius", "fahrenheit"], + }, + "hash_code": {"const": 3456}, + }, + "required": ["city", "state", "unit", "hash_code"], + } + ), + "end": "|End<---get_current_weather>", + }, + { + "begin": "get_current_date|", + "schema": json.dumps( + { + "type": "object", + "properties": { + "timezone": { + "type": "string", + "description": "The timezone to fetch the current date and time for, e.g. 'America/New_York'", + }, + "hash_code": {"const": 4567}, + }, + "required": ["timezone", "hash_code"], + } + ), + "end": "|End<---get_current_date>", + }, + ], +} + +CHECK_INFO = { + 1234: { + "name": "get_current_weather", + "beg_tag": "CALL", + "required": ["city", "state", "unit", "hash_code"], + }, + 2345: { + "name": "get_current_date", + "beg_tag": "CALL", + "required": ["timezone", "hash_code"], + }, + 3456: { + "name": "get_current_weather", + "beg_tag": "call", + "required": ["city", "state", "unit", "hash_code"], + }, + 4567: { + "name": "get_current_date", + "beg_tag": "call", + "required": ["timezone", "hash_code"], + }, +} + +CHAT_COMPLETION_MESSAGES = [ + # messages #0 + [ + SYSTEM_PROMPT, + { + "role": "user", + "content": "You are in New York. Please get the current date and time.", + }, + ], + # messages #1 + [ + SYSTEM_PROMPT, + { + "role": "user", + "content": "You are in New York. Please get the current weather.", + }, + ], + # messages #2 + [ + SYSTEM_PROMPT, + { + "role": "user", + "content": "You are in New York. 
Please get the current date and time, and the weather.", + }, + ], +] + + +@pytest.mark.parametrize("stream", [False, True]) +@pytest.mark.parametrize("messages", CHAT_COMPLETION_MESSAGES) +def test_openai_v1_chat_completion_structural_tag( + served_model: str, + launch_server, # pylint: disable=unused-argument + stream: bool, + messages: List[Dict[str, str]], +): + # `served_model` and `launch_server` are pytest fixtures + # defined in conftest.py. + + payload = { + "model": served_model, + "messages": messages, + "stream": stream, + "response_format": { + "type": "structural_tag", + "tags": STRUCTURAL_TAGS["tags"], + "triggers": STRUCTURAL_TAGS["triggers"], + }, + "max_tokens": 1024, + } + + response = requests.post(OPENAI_V1_CHAT_COMPLETION_URL, json=payload, timeout=60) + if not stream: + check_openai_nonstream_response( + response.json(), + model=served_model, + object_str="chat.completion", + num_choices=1, + finish_reason=["stop"], + ) + else: + responses = [] + for chunk in response.iter_lines(chunk_size=512): + if not chunk or chunk == b"data: [DONE]": + continue + responses.append(json.loads(chunk.decode("utf-8")[6:])) + check_openai_stream_response( + responses, + model=served_model, + object_str="chat.completion.chunk", + num_choices=1, + finish_reason="stop", + ) + + print(f"-----------\nCheck for stream={stream} is passed!\n") + + +if __name__ == "__main__": + MODEL = os.environ.get("MLC_SERVE_MODEL") + if MODEL is None: + raise ValueError( + 'Environment variable "MLC_SERVE_MODEL" not found. ' + "Please set it to model compiled by MLC LLM " + "(e.g., `./dist/Llama-2-7b-chat-hf-q0f16-MLC`) " + ) + + for msg in CHAT_COMPLETION_MESSAGES: + test_openai_v1_chat_completion_structural_tag(MODEL, None, stream=False, messages=msg) + test_openai_v1_chat_completion_structural_tag(MODEL, None, stream=True, messages=msg)
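
Usage sketch (supplementary, not part of the diff): the snippet below shows how a client could exercise the new `structural_tag` response format against a running `mlc_llm serve` instance. The server URL and model path are placeholders; the tag and trigger mirror the JSON tool-call style that `set_structural_tag_from_tools` emits for strict tools, and the request shape follows the test file above.

import json

import requests

MODEL = "./dist/Llama-2-7b-chat-hf-q0f16-MLC"  # placeholder: any model served by `mlc_llm serve`
URL = "http://127.0.0.1:8000/v1/chat/completions"

# Constrain the arguments of a hypothetical `get_time` function with a JSON schema.
schema = {
    "type": "object",
    "properties": {"location": {"type": "string"}},
    "required": ["location"],
}
payload = {
    "model": MODEL,
    "messages": [{"role": "user", "content": "What time is it in Pittsburgh?"}],
    "max_tokens": 256,
    "response_format": {
        "type": "structural_tag",
        # When the model emits a trigger, decoding is constrained to a matching tag:
        # the begin text, then JSON following the schema, then the end text.
        "tags": [
            {
                "begin": '{"name": "get_time", "parameters":',
                "schema": json.dumps(schema),
                "end": "}",
            }
        ],
        "triggers": ['{"name":'],
    },
}
response = requests.post(URL, json=payload, timeout=60)
print(response.json()["choices"][0]["message"]["content"])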