Skip to content

Commit eeb7363

Browse files
Rose22 and LostRuins authored
improvements to tool calling logic (merged changes from old PR branch) (#1855)
* improvements to tool calling logic (merged changes from old PR branch) * added some tweaks for improved tool calls to reuse old ctx, but needs testing. refer to PR. * fixes to some stuff that concedo's modifications broke * fixed error in reasoning * extremely hacky way to cache tool list please fix * oops forgot to add this * slightly less hacky way to preserve the tool list in context * prevented unintended toolcalls from happening when LLM states something irrelevant to toolcall decision * fixed something that broke koboldlite * fixed bug added by concedo that broke jinja tools * experimental further compression of tools array, needs testing * reverted experimental further compression of tools array * final cleanup * add newline after memory insert * changed tool reasoning to always be in json format to enforce including final decision * used new json format to skip extra llm call when not necessary * more catching of possible bad llm output * further cleanup * got it down to just one llm call! * better json format * even better json format * further refinement to json format * further refinement to json format * fixed broken tool calling * single-call enforced json method now seems to work well. removed fallbacks as they are no longer required. --------- Co-authored-by: Concedo <[email protected]>
1 parent 2ef03a8 commit eeb7363

File tree

1 file changed

+105
-52
lines changed

1 file changed

+105
-52
lines changed

koboldcpp.py

100644 → 100755
Lines changed: 105 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -2527,34 +2527,82 @@ def determine_tool_json_to_use(genparams, curr_ctx, assistant_message_start, is_
25272527
# tools handling: Check if user is passing an OpenAI tools array, if so add to end of prompt before assistant prompt unless tool_choice has been set to None
25282528
tools_array = genparams.get('tools', [])
25292529
chosen_tool = genparams.get('tool_choice', "auto")
2530+
messages = genparams.get('messages',[])
2531+
toolmem = genparams.get("memory","")
2532+
25302533
# first handle auto mode, determine whether a tool is needed
25312534
used_tool_json = None
25322535
if not curr_ctx:
25332536
return None
2537+
2538+
# get user's last message and last tool call results
2539+
last_user_message = ""
2540+
tool_call_results = ""
2541+
2542+
if messages:
2543+
reversed_messages = list(reversed(messages))
2544+
for message in reversed_messages:
2545+
if message["role"] == "user":
2546+
last_user_message = message["content"]
2547+
last_user_message = f"\n\nUser's current request: {last_user_message}"
2548+
break
2549+
tool_call_chunk = []
2550+
for message in reversed_messages:
2551+
if message["role"] == "tool":
2552+
tool_call_chunk.append(message["content"])
2553+
else:
2554+
break
2555+
tmp_tool_replies = list(reversed(tool_call_chunk))
2556+
if tmp_tool_replies and len(tmp_tool_replies)>0:
2557+
tool_call_results = f"\n\nTool call responses: {tmp_tool_replies}"
2558+
25342559
if tools_array and len(tools_array) > 0 and chosen_tool is not None and chosen_tool!="none":
2535-
tools_string = json.dumps(tools_array, indent=0)
25362560
should_use_tools = True
2537-
if chosen_tool=="auto":
2538-
# if you want a different template, you can set 'custom_tools_prompt' in the chat completions adapter as follows
2539-
custom_tools_prompt = "Can the user query be answered by a listed tool above? (One word response: yes or no):"
2540-
if is_followup_tool:
2541-
custom_tools_prompt = "Can the user query be further answered by another listed tool above? (If response is already complete, reply NO) (One word response: yes or no):"
2561+
if chosen_tool=="auto" or chosen_tool=="required":
25422562
# note: message string already contains the instruct start tag!
2543-
pollgrammar = r'root ::= "yes" | "no" | "Yes" | "No" | "YES" | "NO"'
2563+
temptoolnames = extract_all_names_from_tool_array(tools_array)
2564+
tempjson = {}
2565+
if chosen_tool=="required":
2566+
custom_tools_prompt_json_format = "Respond with a JSON object using this structure:\r\n{\r\n \"tool_name\": \"exact_tool_name_here\"\r\n}\r\n\r\nRules:\r\n- You must pick one of the tools to use, pick the most suitable tool."
2567+
tempjson = {"type":"object","properties":{"tool_name":{"type":"string","enum":temptoolnames}},"required":["tool_name"],"additionalProperties":False}
2568+
else:
2569+
temptoolnames.append("null")
2570+
custom_tools_prompt_json_format = "Respond with a JSON object using this structure:\r\n{\r\n \"reasoning\": \"Your reasoning here\",\r\n \"final_decision\": \"yes\" or \"no\",\r\n \"tool_name\": \"exact_tool_name_here\" or \"null\"\r\n}\r\n\r\nRules:\r\n- Output only the JSON object. Do NOT add anything before or after the json object.\r\n- final_decision must be exactly \"yes\" or \"no\"\r\n- tool_name must be either an exact tool name, or if no tool is required, an empty string: \"\"\r\n- Keep reasoning short, maximum one or two sentences.\r\n- No unnecessary comments"
2571+
tempjson = {"type":"object","properties":{"reasoning":{"type":"string"},"final_decision":{"type":"string","enum":["yes","no","Yes","No","YES","NO"," yes"," no"," Yes"," No"," YES"," NO"]},"tool_name":{"type":"string","enum":temptoolnames}},"required":["reasoning","final_decision","tool_name"],"additionalProperties":False}
2572+
toolquerygrammar = convert_json_to_gbnf(tempjson)
2573+
2574+
if not is_followup_tool:
2575+
custom_tools_prompt = "Is calling one of the tools listed above absolutely essential to answer user's current request, or is a tool call optional?"
2576+
custom_tools_prompt_processed = f"{curr_ctx}{last_user_message}\n\n{custom_tools_prompt} {custom_tools_prompt_json_format}{assistant_message_start}"
2577+
else:
2578+
custom_tools_prompt = "Given the tool call response to the user's current request, is another tool call needed to further answer user's message?"
2579+
custom_tools_prompt_processed = f"{curr_ctx}{last_user_message}{tool_call_results}\n\n{custom_tools_prompt} {custom_tools_prompt_json_format}{assistant_message_start}"
2580+
2581+
# first, prompt to see if a tool call is needed using the prompt above.
2582+
# the result is a short explanation by the LLM on why a tool call is or is not needed, along with its final decision at the end.
25442583
temp_poll = {
2545-
"prompt": f"{curr_ctx}\n\nTool List:\n{tools_string}\n\n{custom_tools_prompt}{assistant_message_start}",
2546-
"max_length":5,
2584+
"prompt": custom_tools_prompt_processed,
2585+
"memory": toolmem,
2586+
"max_length":300,
25472587
"temperature":0.1,
25482588
"top_k":1,
25492589
"rep_pen":1,
25502590
"ban_eos_token":False,
2551-
"grammar":pollgrammar
2552-
}
2591+
"grammar":toolquerygrammar
2592+
}
25532593
temp_poll_result = generate(genparams=temp_poll)
2554-
if temp_poll_result and "yes" not in temp_poll_result['text'].lower():
2555-
should_use_tools = False
2594+
temp_poll_text = temp_poll_result['text'].strip().rstrip('.')
2595+
temp_poll_data_arr = extract_json_from_string(temp_poll_text)
2596+
temp_poll_data = temp_poll_data_arr[0] if (temp_poll_data_arr and len(temp_poll_data_arr)>0) else None
2597+
2598+
if temp_poll_data:
2599+
if chosen_tool!="required" and ("yes" not in temp_poll_data.get("final_decision","").lower() or "null" in temp_poll_data.get("tool_name","").lower()):
2600+
should_use_tools = False
2601+
elif (chosen_tool=="auto" or chosen_tool=="required") and "null" not in temp_poll_data.get("tool_name","").lower():
2602+
chosen_tool = temp_poll_data.get("tool_name","").lower().strip()
2603+
25562604
if not args.quiet:
2557-
print(f"\nRelevant tool is listed: {temp_poll_result['text']} ({should_use_tools})")
2605+
print(f"\n[TOOLCALL REASONING]: {temp_poll_text}")
25582606

25592607
if should_use_tools:
25602608
#first, try and extract a specific tool if selected
@@ -2567,38 +2615,26 @@ def determine_tool_json_to_use(genparams, curr_ctx, assistant_message_start, is_
25672615
toolnames = extract_all_names_from_tool_array(tools_array)
25682616
if len(toolnames) == 1:
25692617
used_tool_json = extract_tool_info_from_tool_array(toolnames[0], tools_array)
2570-
else:
2571-
pollgrammar = ""
2572-
for name in toolnames:
2573-
pollgrammar += ("" if pollgrammar=="" else " | ")
2574-
pollgrammar += "\"" + name + "\""
2575-
pollgrammar += " | \"no_tool\""
2576-
pollgrammar = r'root ::= ' + pollgrammar
2577-
decide_tool_prompt = "Which of the listed tools should be used next? Pick exactly one. If no tool is suitable, reply no_tool. (Reply directly with the selected tool's name):"
2578-
temp_poll = {
2579-
"prompt": f"{curr_ctx}\n\nTool List:\n{tools_string}\n\n{decide_tool_prompt}{assistant_message_start}",
2580-
"max_length":16,
2581-
"temperature":0.1,
2582-
"top_k":1,
2583-
"rep_pen":1,
2584-
"ban_eos_token":False,
2585-
"grammar":pollgrammar
2586-
}
2587-
temp_poll_result = generate(genparams=temp_poll)
2588-
if temp_poll_result:
2589-
raw = temp_poll_result['text'].lower()
2590-
if "no_tool" in raw:
2591-
print(f"\nNo suitable tool found.")
2592-
else:
2593-
for name in toolnames:
2594-
if name.lower() in raw:
2595-
used_tool_json = extract_tool_info_from_tool_array(name, tools_array)
2596-
if not args.quiet:
2597-
print(f"\nAttempting to use tool: {name}")
2598-
break
25992618

26002619
return used_tool_json
26012620

2621+
def compress_tools_array(tools_array):
    """Reduce an OpenAI-style tools array to a compact per-tool summary.

    Each usable entry of ``tools_array`` is collapsed to a dict with three keys:
    ``name``, ``description`` and ``properties``, where ``properties`` maps each
    parameter name to the value of its schema's ``type`` field.

    Fields that a client may legitimately omit are tolerated instead of raising
    ``KeyError``: a missing ``description`` becomes ``""``, a missing or null
    ``parameters``/``properties`` becomes an empty mapping, and a property schema
    without a ``type`` (e.g. one using ``anyOf`` or ``$ref``) is reported as
    ``"any"``. Entries that are not dicts, or that lack a ``function`` dict with
    a non-empty ``name``, are skipped because they cannot be called by name.

    Args:
        tools_array: list of tool dicts of the form ``{"function": {...}}``;
            ``None`` or an empty list yields an empty result.

    Returns:
        A new list of compact tool dicts, in the same order as the input.
    """
    tools_array_filtered = []
    for tool_dict in tools_array or []:
        tool_data = tool_dict.get("function") if isinstance(tool_dict, dict) else None
        if not isinstance(tool_data, dict) or not tool_data.get("name"):
            # nothing to identify the tool by, so it is useless in a tool list
            continue

        # "parameters" may be absent or explicitly null for zero-argument tools
        params = tool_data.get("parameters")
        props = params.get("properties") if isinstance(params, dict) else None
        if not isinstance(props, dict):
            props = {}

        tool_props = {}
        for prop_name, prop_data in props.items():
            # keep the parameter visible even when its schema has no "type"
            if isinstance(prop_data, dict) and "type" in prop_data:
                tool_props[prop_name] = prop_data["type"]
            else:
                tool_props[prop_name] = "any"

        tools_array_filtered.append({
            "name": tool_data["name"],
            "description": tool_data.get("description", ""),
            "properties": tool_props,
        })

    return tools_array_filtered
2637+
26022638
def transform_genparams(genparams, api_format, use_jinja):
26032639
global chatcompl_adapter, maxctx
26042640

@@ -2704,7 +2740,7 @@ def transform_genparams(genparams, api_format, use_jinja):
27042740
assistant_message_start = adapter_obj.get("assistant_start", "\n### Response:\n")
27052741
assistant_message_end = adapter_obj.get("assistant_end", "")
27062742
assistant_message_gen = adapter_obj.get("assistant_gen", assistant_message_start)
2707-
tools_message_start = adapter_obj.get("tools_start", "\nTool Results:\n")
2743+
tools_message_start = adapter_obj.get("tools_start", "")
27082744
tools_message_end = adapter_obj.get("tools_end", "")
27092745
images_added = []
27102746
audio_added = []
@@ -2747,6 +2783,13 @@ def transform_genparams(genparams, api_format, use_jinja):
27472783
if jinjatools and len(jinjatools)>0:
27482784
genparams["using_openai_tools"] = True
27492785
else:
2786+
if jinjatools:
2787+
# inject the tools list at the top of the context window, even if context has shifted
2788+
# uses koboldcpp's special memory parameter
2789+
tools_string = f"{system_message_start}### Available Tools:\n{json.dumps(compress_tools_array(jinjatools), indent=0)}{system_message_end}\n"
2790+
exist_mem = genparams.get('memory', "")
2791+
genparams["memory"] = tools_string + exist_mem
2792+
27502793
for message in messages_array:
27512794
message_index += 1
27522795
if message['role'] == "system":
@@ -2757,6 +2800,9 @@ def transform_genparams(genparams, api_format, use_jinja):
27572800
messages_string += assistant_message_start
27582801
elif message['role'] == "tool":
27592802
messages_string += tools_message_start
2803+
tcid = message.get("tool_call_id","")
2804+
tcid = ("" if not tcid else f" {tcid}")
2805+
messages_string += f"\nReceived results of function call{tcid}:\n"
27602806

27612807
# content can be a string or an array of objects
27622808
curr_content = message.get("content",None)
@@ -2768,9 +2814,16 @@ def transform_genparams(genparams, api_format, use_jinja):
27682814
if not curr_content:
27692815
if "tool_calls" in message:
27702816
try:
2771-
if len(message.get("tool_calls"))>0:
2772-
tcfnname = message.get("tool_calls")[0].get("function").get("name")
2773-
messages_string += f"\n(Made a function call to {tcfnname})\n"
2817+
nlstart = True
2818+
for tc in message.get("tool_calls"):
2819+
if nlstart:
2820+
nlstart = False
2821+
messages_string += "\n"
2822+
tcid = tc.get("id","")
2823+
tcfnname = tc.get("function").get("name")
2824+
tcfnargs = tc.get("function").get("arguments","")
2825+
tcfnargs = (f" with arguments={tcfnargs}" if tcfnargs else "")
2826+
messages_string += f"(Made a function call {tcid} to {tcfnname}{tcfnargs})\n"
27742827
except Exception:
27752828
messages_string += "\n(Made a function call)\n"
27762829
pass # do nothing
@@ -2815,7 +2868,6 @@ def transform_genparams(genparams, api_format, use_jinja):
28152868
tool_json_formatting_instruction = f"\nPlease use the provided schema to fill the parameters to create a function call for {toolname}, in the following format: " + json.dumps([{"id": "call_001", "type": "function", "function": {"name": f"{toolname}", "arguments": {"first property key": "first property value", "second property key": "second property value"}}}], indent=0)
28162869
messages_string += f"\n\nJSON Schema:\n{used_tool_json}\n\n{tool_json_formatting_instruction}{assistant_message_start}"
28172870

2818-
28192871
if message['role'] == "system":
28202872
messages_string += system_message_end
28212873
elif message['role'] == "user":
@@ -4210,10 +4262,6 @@ def do_POST(self):
42104262
is_embeddings = False
42114263
response_body = None
42124264
use_jinja = args.jinja
4213-
if use_jinja and not args.jinja_tools:
4214-
tmptools = genparams.get('tools', [])
4215-
if tmptools and len(tmptools) > 0:
4216-
use_jinja = False # not allowed to use tools with jinja
42174265

42184266
if self.path.endswith('/api/admin/check_state'):
42194267
if global_memory and args.admin and args.admindir and os.path.exists(args.admindir) and self.check_header_password(args.adminpassword):
@@ -4349,6 +4397,11 @@ def do_POST(self):
43494397
if args.debugmode >= 1:
43504398
trunc_len = 32000
43514399

4400+
if use_jinja and not args.jinja_tools:
4401+
tmptools = genparams.get('tools', [])
4402+
if tmptools and len(tmptools) > 0:
4403+
use_jinja = False # not allowed to use tools with jinja
4404+
43524405
printablegenparams_raw = truncate_long_json(genparams,trunc_len)
43534406
utfprint("\nInput: " + json.dumps(printablegenparams_raw,ensure_ascii=False),1)
43544407

0 commit comments

Comments
 (0)