Skip to content

Commit 68e9832

Browse files
Add assistant.tool_execution_callback (#884)
* Add assistant.tool_execution_callback * CHANGELOG entry
1 parent 5bae916 commit 68e9832

File tree

4 files changed

+34
-5
lines changed

4 files changed

+34
-5
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
- [SECURITY]: A change which fixes a security vulnerability.
1111

1212
## [Unreleased]
13+
- [FEATURE] [https://github.com/patterns-ai-core/langchainrb/pull/884] Add `tool_execution_callback` to `Langchain::Assistant`, a callback function (proc, lambda) that is called right before a tool is executed
1314

1415
## [0.19.1] - 2024-11-21
1516
- [FEATURE] [https://github.com/patterns-ai-core/langchainrb/pull/858] Assistant, when using Anthropic, now also accepts image_url in the message.

README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,13 @@ Note that streaming is not currently supported for all LLMs.
536536
* `tool_choice`: Specifies how tools should be selected. Default: "auto". A specific tool function name can be passed. This will force the Assistant to **always** use this function.
537537
* `parallel_tool_calls`: Whether to make multiple parallel tool calls. Default: true
538538
* `add_message_callback`: A callback function (proc, lambda) that is called when any message is added to the conversation (optional)
539+
```ruby
540+
assistant.add_message_callback = -> (message) { puts "New message: #{message}" }
541+
```
542+
* `tool_execution_callback`: A callback function (proc, lambda) that is called right before a tool is executed (optional)
543+
```ruby
544+
assistant.tool_execution_callback = -> (tool_call_id, tool_name, method_name, tool_arguments) { puts "Executing tool_call_id: #{tool_call_id}, tool_name: #{tool_name}, method_name: #{method_name}, tool_arguments: #{tool_arguments}" }
545+
```
539546

540547
### Key Methods
541548
* `add_message`: Adds a user message to the messages array

lib/langchain/assistant.rb

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ class Assistant
2424

2525
attr_accessor :tools,
2626
:add_message_callback,
27+
:tool_execution_callback,
2728
:parallel_tool_calls
2829

2930
# Create a new assistant
@@ -35,14 +36,17 @@ class Assistant
3536
# @param parallel_tool_calls [Boolean] Whether or not to run tools in parallel
3637
# @param messages [Array<Langchain::Assistant::Messages::Base>] The messages
3738
# @param add_message_callback [Proc] A callback function (Proc or lambda) that is called when any message is added to the conversation
39+
# @param tool_execution_callback [Proc] A callback function (Proc or lambda) that is called right before a tool function is executed
3840
def initialize(
3941
llm:,
4042
tools: [],
4143
instructions: nil,
4244
tool_choice: "auto",
4345
parallel_tool_calls: true,
4446
messages: [],
47+
# Callbacks
4548
add_message_callback: nil,
49+
tool_execution_callback: nil,
4650
&block
4751
)
4852
unless tools.is_a?(Array) && tools.all? { |tool| tool.class.singleton_class.included_modules.include?(Langchain::ToolDefinition) }
@@ -52,11 +56,8 @@ def initialize(
5256
@llm = llm
5357
@llm_adapter = LLM::Adapter.build(llm)
5458

55-
# TODO: Validate that it is, indeed, a Proc or lambda
56-
if !add_message_callback.nil? && !add_message_callback.respond_to?(:call)
57-
raise ArgumentError, "add_message_callback must be a callable object, like Proc or lambda"
58-
end
59-
@add_message_callback = add_message_callback
59+
@add_message_callback = add_message_callback if validate_callback!("add_message_callback", add_message_callback)
60+
@tool_execution_callback = tool_execution_callback if validate_callback!("tool_execution_callback", tool_execution_callback)
6061

6162
self.messages = messages
6263
@tools = tools
@@ -359,6 +360,8 @@ def run_tools(tool_calls)
359360
t.class.tool_name == tool_name
360361
end or raise ArgumentError, "Tool: #{tool_name} not found in assistant.tools"
361362

363+
# Call the callback if set
364+
tool_execution_callback.call(tool_call_id, tool_name, method_name, tool_arguments) if tool_execution_callback # rubocop:disable Style/SafeNavigation
362365
output = tool_instance.send(method_name, **tool_arguments)
363366

364367
submit_tool_output(tool_call_id: tool_call_id, output: output)
@@ -392,5 +395,13 @@ def record_used_tokens(prompt_tokens, completion_tokens, total_tokens_from_opera
392395
def available_tool_names
393396
llm_adapter.available_tool_names(tools)
394397
end
398+
399+
def validate_callback!(attr_name, callback)
400+
if !callback.nil? && !callback.respond_to?(:call)
401+
raise ArgumentError, "#{attr_name} must be a callable object, like Proc or lambda"
402+
end
403+
404+
true
405+
end
395406
end
396407
end

spec/langchain/assistant/assistant_spec.rb

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,16 @@
2020
end
2121
end
2222

23+
describe "#tool_execution_callback" do
24+
it "raises an error if the callback is not a Proc" do
25+
expect { described_class.new(llm: llm, tool_execution_callback: "foo") }.to raise_error(ArgumentError)
26+
end
27+
28+
it "does not raise an error if the callback is a Proc" do
29+
expect { described_class.new(llm: llm, tool_execution_callback: -> {}) }.not_to raise_error
30+
end
31+
end
32+
2333
it "raises an error if LLM class does not implement `chat()` method" do
2434
llm = Langchain::LLM::Replicate.new(api_key: "123")
2535
expect { described_class.new(llm: llm) }.to raise_error(ArgumentError)

0 commit comments

Comments
 (0)