Skip to content

[Serving] Add Structural-Tag api to RequestResponseFormat #3187

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion 3rdparty/tvm
Submodule tvm updated 111 files
2 changes: 1 addition & 1 deletion 3rdparty/xgrammar
Submodule xgrammar updated 158 files
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ add_subdirectory(${TOKENZIER_CPP_PATH} tokenizers EXCLUDE_FROM_ALL)
set(XGRAMMAR_PATH 3rdparty/xgrammar)
tvm_file_glob(GLOB_RECURSE MLC_LLM_SRCS cpp/*.cc)
tvm_file_glob(GLOB_RECURSE XGRAMMAR_SRCS ${XGRAMMAR_PATH}/cpp/*.cc)
# Exclude xgrammar's Python-binding sources; only the core library is compiled in.
# NOTE(review): both the old (pybind) and new (nanobind) filter lines appear here —
# this looks like a merged diff artifact. Keeping both is harmless since each regex
# only filters its own directory, but confirm which binding directory exists upstream.
list(FILTER XGRAMMAR_SRCS EXCLUDE REGEX "${XGRAMMAR_PATH}/cpp/pybind/.*\\.cc")
list(FILTER XGRAMMAR_SRCS EXCLUDE REGEX "${XGRAMMAR_PATH}/cpp/nanobind/.*\\.cc")
list(APPEND MLC_LLM_SRCS ${XGRAMMAR_SRCS})
add_library(mlc_llm_objs OBJECT ${MLC_LLM_SRCS})

Expand Down
54 changes: 51 additions & 3 deletions cpp/serve/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,13 +42,43 @@ Result<ResponseFormat> ResponseFormat::FromJSON(const picojson::object& config)
ResponseFormat res;
res.type = json::LookupOrDefault<std::string>(config, "type", "text");

if (res.type != "text" && res.type != "function" && res.type != "json_object" &&
res.type != "structural_tag") {
return TResult::Error("Uknonwn response_format type " + res.type);
}

std::optional<std::string> schema = json::LookupOptional<std::string>(config, "schema");
if (schema.has_value()) {
res.schema = schema.value();
}

if (res.type != "text" && res.type != "function" && res.type != "json_object") {
return TResult::Error("Uknonwn response_format type " + res.type);
if (auto tags_obj = json::LookupOptional<picojson::array>(config, "tags")) {
auto tags = Array<Array<String>>();
for (auto tag_obj : tags_obj.value()) {
Array<String> tag = Array<String>();
std::optional<std::string> begin =
json::LookupOptional<std::string>(tag_obj.get<picojson::object>(), "begin");
std::optional<std::string> schema =
json::LookupOptional<std::string>(tag_obj.get<picojson::object>(), "schema");
std::optional<std::string> end =
json::LookupOptional<std::string>(tag_obj.get<picojson::object>(), "end");
if (!(begin.has_value() && schema.has_value() && end.has_value())) {
return TResult::Error("Miss tag attribute.");
}
tag.push_back(begin.value());
tag.push_back(schema.value());
tag.push_back(end.value());
tags.push_back(std::move(tag));
}
res.tags = tags;
}

if (auto triggers_obj = json::LookupOptional<picojson::array>(config, "triggers")) {
auto triggers = Array<String>();
for (auto trigger : triggers_obj.value()) {
triggers.push_back(trigger.get<std::string>());
}
res.triggers = triggers;
}

return TResult::Ok(res);
Expand All @@ -60,6 +90,24 @@ picojson::object ResponseFormat::AsJSON() const {
// Serialize the optional fields only when present, mirroring FromJSON's inputs.
if (schema.defined()) {
  config["schema"] = picojson::value(schema.value().operator std::string());
}
if (tags.defined()) {
  picojson::array tags_obj;
  // Each tag is a 3-element [begin, schema, end] array. Iterate by const
  // reference to avoid copying, and convert String -> std::string explicitly,
  // consistent with the "schema" serialization above.
  for (const auto& tag : tags.value()) {
    picojson::array tag_obj;
    tag_obj.emplace_back(tag[0].operator std::string());
    tag_obj.emplace_back(tag[1].operator std::string());
    tag_obj.emplace_back(tag[2].operator std::string());
    tags_obj.emplace_back(tag_obj);
  }
  config["tags"] = picojson::value(tags_obj);
}
if (triggers.defined()) {
  picojson::array triggers_obj;
  // Const reference avoids a String -> std::string copy per element.
  for (const auto& trigger : triggers.value()) {
    triggers_obj.emplace_back(trigger.operator std::string());
  }
  config["triggers"] = picojson::value(triggers_obj);
}
return config;
}

Expand Down Expand Up @@ -1073,4 +1121,4 @@ Result<bool> ModelsUseKVCache(const std::vector<picojson::object>& model_configs

} // namespace serve
} // namespace llm
} // namespace mlc
} // namespace mlc
4 changes: 3 additions & 1 deletion cpp/serve/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ using namespace tvm::runtime;
struct ResponseFormat {
// Response format type: "text", "function", "json_object", or "structural_tag".
String type = "text";
// Optional JSON schema; used when type == "json_object".
Optional<String> schema = NullOpt;
// For type == "structural_tag": each tag is a [begin, schema, end] triple.
Optional<Array<Array<String>>> tags = NullOpt;
// For type == "structural_tag": trigger strings for tag matching.
Optional<Array<String>> triggers = NullOpt;
/*!
* \brief Create debug config from JSON.
* \param config_json The json string for generation config
Expand Down Expand Up @@ -448,4 +450,4 @@ inline PrefillMode PrefillModeFromString(const std::string& prefill_mode) {
} // namespace llm
} // namespace mlc

#endif // MLC_LLM_SERVE_CONFIG_H_
#endif // MLC_LLM_SERVE_CONFIG_H_
30 changes: 22 additions & 8 deletions cpp/serve/engine.cc
Original file line number Diff line number Diff line change
Expand Up @@ -463,9 +463,11 @@ class EngineImpl : public Engine {
ModelWorkspace{model->AllocEmbeddingTensor(), model->AllocHiddenStatesTensor()});
}
// - Initialize tokenizer and grammar
// The tokenizer is loaded from the model path; its post-processed token table
// seeds the grammar compiler used for grammar-constrained decoding.

n->tokenizer_ = Tokenizer::FromPath(engine_config->model, GetTokenizerInfo(model_configs[0]));
n->token_table_ = n->tokenizer_->PostProcessedTokenTable();
// NOTE(review): the next line looks like the pre-change (CachedGrammarCompiler)
// variant left interleaved by the diff; only one compiler member should be
// initialized — confirm against the member declarations in this class.
n->cached_grammar_compiler_ = xgrammar::CachedGrammarCompiler(n->token_table_);
// TODO: check 'vocab_size' of TokenizerInfo
n->grammar_compiler_ = xgrammar::GrammarCompiler(xgrammar::TokenizerInfo(n->token_table_));
// - Create the logit processor and sampler, and
// the DraftTokenWorkspaceManager for speculative decoding.
int max_num_tokens = engine_config->max_num_sequence;
Expand Down Expand Up @@ -975,13 +977,25 @@ class EngineImpl : public Engine {
* is not JSON, return std::nullopt. */
/*!
 * \brief Build the compiled grammar for the given response format.
 * \param response_format The response format of the request.
 * \return std::nullopt for plain "text" responses (no grammar constraint);
 * otherwise the compiled grammar. For "json_object" this is either the builtin
 * JSON grammar or the provided JSON schema; for other types ("structural_tag")
 * it is a structural-tag grammar built from the stored tags and triggers.
 */
std::optional<xgrammar::CompiledGrammar> GetGrammarFromResponseFormat(
    const ResponseFormat& response_format) {
  if (response_format.type == "text") {
    return std::nullopt;
  } else if (response_format.type == "json_object") {
    if (!response_format.schema) {
      return grammar_compiler_.CompileBuiltinJSONGrammar();
    }
    return grammar_compiler_.CompileJSONSchema(response_format.schema.value());
  } else {
    // "structural_tag": convert the stored [begin, schema, end] triples and
    // trigger strings into xgrammar's structural-tag inputs.
    // NOTE(review): this dereferences tags/triggers without a defined() check —
    // confirm ResponseFormat::FromJSON guarantees both are set for this type.
    std::vector<xgrammar::StructuralTagItem> tags;
    std::vector<std::string> triggers;
    for (const auto& tag : response_format.tags.value()) {
      tags.emplace_back(xgrammar::StructuralTagItem{tag[0], tag[1], tag[2]});
    }
    for (const auto& trigger : response_format.triggers.value()) {
      triggers.emplace_back(trigger);
    }
    return grammar_compiler_.CompileStructuralTag(std::move(tags), std::move(triggers));
  }
}

Expand All @@ -992,8 +1006,8 @@ class EngineImpl : public Engine {
// internal tokenizer
Tokenizer tokenizer_;
// Post-processed token table derived from the tokenizer.
std::vector<std::string> token_table_;
// Cached grammar compiler for grammar matching.
// NOTE(review): this member appears to be the pre-change declaration kept
// alongside its replacement below (diff artifact); confirm that only
// grammar_compiler_ remains in the final code.
xgrammar::CachedGrammarCompiler cached_grammar_compiler_;
// Grammar compiler for grammar matching.
xgrammar::GrammarCompiler grammar_compiler_;
// Models
Array<Model> models_;
// Device that the models run on.
Expand Down
4 changes: 2 additions & 2 deletions cpp/serve/request_state.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ RequestModelState::RequestModelState(
if (compiled_grammar.has_value()) {
// TODO(yixin): set rollback limit to a configurable value.
// NOTE(review): the two GrammarMatcher constructions below look like the old
// (5-argument) and new (4-argument) variants of the same diff line; only the
// 4-argument call matches the updated xgrammar API — confirm and drop the
// stale one.
n->grammar_matcher =
xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, std::nullopt, 10);
xgrammar::GrammarMatcher(compiled_grammar.value(), std::nullopt, false, 10);

n->request = std::move(request);
n->request = std::move(request);
Expand All @@ -44,7 +44,7 @@ bool RequestModelStateNode::RequireNextTokenBitmask() { return grammar_matcher.h
/*!
 * \brief Fill the given bitmask tensor with the set of tokens allowed next
 * according to the current state of the grammar matcher.
 * \param bitmask The destination DLTensor to fill.
 * \note Requires an attached grammar matcher (checked via ICHECK). The span
 * previously contained both the old GetNextTokenBitmask and the new
 * FillNextTokenBitmask calls (diff artifact); only the updated xgrammar API
 * call is kept.
 */
void RequestModelStateNode::GetNextTokenBitmask(DLTensor* bitmask) {
  ICHECK(grammar_matcher.has_value());

  grammar_matcher->FillNextTokenBitmask(bitmask);
}

void RequestModelStateNode::CommitToken(SampleResult sampled_token) {
Expand Down
27 changes: 25 additions & 2 deletions python/mlc_llm/protocol/openai_api_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,12 +86,35 @@ class ModelResponse(BaseModel):


class RequestResponseFormat(BaseModel):
    """Response-format constraint attached to a request.

    ``type`` selects the constraint mode:

    - ``"text"``: no constraint (default).
    - ``"json_object"``: constrain output to JSON, optionally matching ``json_schema``.
    - ``"structural_tag"``: constrain output using ``tags`` and ``triggers``.
    """

    type: Literal["text", "json_object", "structural_tag"] = "text"

    # This field is named json_schema instead of schema because BaseModel defines a
    # method called schema. During construction of RequestResponseFormat, the key
    # "schema" should still be used: RequestResponseFormat(type="json_object", schema="{}")
    json_schema: Optional[str] = Field(default=None, alias="schema")

    # These fields are only used for type="structural_tag". Each tag must be a dict
    # with exactly the keys "begin", "schema" and "end".
    tags: Optional[List[Dict[str, str]]] = Field(default=None, alias="tags")
    triggers: Optional[List[str]] = Field(default=None, alias="triggers")

    @model_validator(mode="after")
    def check_request_response_format(self) -> "RequestResponseFormat":
        """Check if the RequestResponseFormat is valid."""
        if self.type == "structural_tag":
            if self.tags is None or self.triggers is None:
                raise ValueError("structural_tag type must contain keys 'tags' and 'triggers'.")
            for tag in self.tags:
                if set(tag.keys()) != {"begin", "schema", "end"}:
                    raise ValueError(
                        "Each tag must contain exactly 'begin', 'schema' and 'end' keys. "
                        f"Got keys: {list(tag.keys())}."
                    )
        elif self.tags is not None or self.triggers is not None:
            # Advisory only: the original `raise Warning(...)` would abort validation
            # with an exception, which contradicts the message's "should" wording.
            import warnings

            warnings.warn(
                "'tags' and 'triggers' attributes should be used when type='structural_tag'"
            )

        return self


class CompletionRequest(BaseModel):
Expand Down
Loading