Skip to content

Commit 64edafd

Browse files
authored
Add more metadata to the provider object (#407)
* Include a name for logging and printing * Include the model, since we use it just about everywhere `Provider` no longer seems like quite the right name for this class, but I think we can leave it as is for now. Fixes #406
1 parent d1409c0 commit 64edafd

28 files changed

+119
-57
lines changed

NEWS.md

+3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
# ellmer (development version)
22

3+
* `Provider` gains `name` and `model` fields (#406). These are now reported when
4+
you print a chat object and used in `token_usage()`.
5+
36
* New `interpolate_package()` to make it easier to interpolate from prompts
47
stored in the `inst/prompts` inside a package (#164).
58

R/chat.R

+5-7
Original file line numberDiff line numberDiff line change
@@ -641,20 +641,18 @@ is_chat <- function(x) {
641641

642642
#' @export
643643
print.Chat <- function(x, ...) {
644+
provider <- x$get_provider()
644645
turns <- x$get_turns(include_system_prompt = TRUE)
645646

646647
tokens <- x$tokens(include_system_prompt = TRUE)
647648
tokens_user <- sum(tokens$tokens[tokens$role == "user"])
648649
tokens_assistant <- sum(tokens$tokens[tokens$role == "assistant"])
649650

650-
cat(paste0(
651+
cat(paste_c(
651652
"<Chat",
652-
" turns=",
653-
length(turns),
654-
" tokens=",
655-
tokens_user,
656-
"/",
657-
tokens_assistant,
653+
c(" ", provider@name, "/", provider@model),
654+
c(" turns=", length(turns)),
655+
c(" tokens=", tokens_user, "/", tokens_assistant),
658656
">\n"
659657
))
660658

R/provider-azure.R

+1
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ chat_azure <- function(
8989
credentials <- credentials %||% default_azure_credentials(api_key, token)
9090

9191
provider <- ProviderAzure(
92+
name = "Azure/OpenAI",
9293
base_url = paste0(endpoint, "/openai/deployments/", deployment_id),
9394
model = deployment_id,
9495
params = params,

R/provider-bedrock.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ chat_bedrock <- function(
6666
echo <- check_echo(echo)
6767

6868
provider <- ProviderBedrock(
69+
name = "AWS/Bedrock",
6970
base_url = "",
7071
model = model,
7172
profile = profile,
@@ -81,7 +82,6 @@ ProviderBedrock <- new_class(
8182
"ProviderBedrock",
8283
parent = Provider,
8384
properties = list(
84-
model = prop_string(),
8585
profile = prop_string(allow_null = TRUE),
8686
region = prop_string(),
8787
cache = class_list
@@ -264,7 +264,7 @@ method(value_turn, ProviderBedrock) <- function(
264264
})
265265

266266
tokens <- c(result$usage$inputTokens, result$usage$outputTokens)
267-
tokens_log("Bedrock", tokens)
267+
tokens_log(provider, tokens)
268268

269269
Turn(result$output$message$role, contents, json = result, tokens = tokens)
270270
}

R/provider-claude.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ chat_claude <- function(
5252
}
5353

5454
provider <- ProviderClaude(
55+
name = "Anthropic",
5556
model = model,
5657
params = params %||% params(),
5758
extra_args = api_args,
@@ -81,7 +82,6 @@ ProviderClaude <- new_class(
8182
parent = Provider,
8283
properties = list(
8384
api_key = prop_string(),
84-
model = prop_string(),
8585
beta_headers = class_character
8686
)
8787
)
@@ -279,7 +279,7 @@ method(value_turn, ProviderClaude) <- function(
279279
})
280280

281281
tokens <- c(result$usage$input_tokens, result$usage$output_tokens)
282-
tokens_log("Claude", tokens)
282+
tokens_log(provider, tokens)
283283

284284
Turn(result$role, contents, json = result, tokens = tokens)
285285
}

R/provider-cortex.R

+17-1
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ chat_cortex_analyst <- function(
8080
credentials <- credentials %||% default_snowflake_credentials(account)
8181

8282
provider <- ProviderCortex(
83+
name = "Snowflake/CortexAnalyst",
8384
account = account,
8485
credentials = credentials,
8586
model_spec = model_spec,
@@ -124,6 +125,7 @@ ProviderCortex <- new_class(
124125
"ProviderCortex",
125126
parent = Provider,
126127
constructor = function(
128+
name,
127129
account,
128130
credentials,
129131
model_spec = NULL,
@@ -137,7 +139,12 @@ ProviderCortex <- new_class(
137139
!!!extra_args
138140
))
139141
new_object(
140-
Provider(base_url = base_url, extra_args = extra_args),
142+
Provider(
143+
name = name,
144+
base_url = base_url,
145+
extra_args = extra_args,
146+
model = ""
147+
),
141148
account = account,
142149
credentials = credentials
143150
)
@@ -149,6 +156,15 @@ ProviderCortex <- new_class(
149156
)
150157
)
151158

159+
provider_cortex_test <- function(..., credentials = function(account) list()) {
160+
ProviderCortex(
161+
name = "Cortex",
162+
account = "testorg-test_account",
163+
credentials = credentials,
164+
...
165+
)
166+
}
167+
152168
# See: https://docs.snowflake.com/en/developer-guide/snowflake-rest-api/reference/cortex-analyst
153169
# https://docs.snowflake.com/en/user-guide/snowflake-cortex/cortex-analyst/tutorials/tutorial-1#step-3-create-a-streamlit-app-to-talk-to-your-data-through-cortex-analyst
154170
method(chat_request, ProviderCortex) <- function(

R/provider-databricks.R

+1
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@ chat_databricks <- function(
6969
credentials <- default_databricks_credentials(workspace)
7070
}
7171
provider <- ProviderDatabricks(
72+
name = "Databricks",
7273
base_url = workspace,
7374
model = model,
7475
extra_args = api_args,

R/provider-deepseek.R

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ chat_deepseek <- function(
4141
}
4242

4343
provider <- ProviderDeepSeek(
44+
name = "DeepSeek",
4445
base_url = base_url,
4546
model = model,
4647
seed = seed,

R/provider-gemini.R

+2-1
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ chat_gemini <- function(
4343
credentials <- default_google_credentials(api_key)
4444

4545
provider <- ProviderGemini(
46+
name = "Google/Gemini",
4647
base_url = base_url,
4748
model = model,
4849
params = params %||% params(),
@@ -205,7 +206,7 @@ method(value_turn, ProviderGemini) <- function(
205206
usage$promptTokenCount %||% NA_integer_,
206207
usage$candidatesTokenCount %||% NA_integer_
207208
)
208-
tokens_log("Gemini", tokens)
209+
tokens_log(provider, tokens)
209210

210211
Turn("assistant", contents, json = result, tokens = tokens)
211212
}

R/provider-groq.R

+1
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ chat_groq <- function(
4242
}
4343

4444
provider <- ProviderGroq(
45+
name = "Groq",
4546
base_url = base_url,
4647
model = model,
4748
seed = seed,

R/provider-ollama.R

+1
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@ chat_ollama <- function(
5454
echo <- check_echo(echo)
5555

5656
provider <- ProviderOllama(
57+
name = "Ollama",
5758
base_url = file.path(base_url, "v1"), ## the v1 portion of the path is added for openAI compatible API
5859
model = model,
5960
seed = seed,

R/provider-openai.R

+2-5
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ chat_openai <- function(
7272
}
7373

7474
provider <- ProviderOpenAI(
75+
name = "OpenAI",
7576
base_url = base_url,
7677
model = model,
7778
params = params,
@@ -96,7 +97,6 @@ ProviderOpenAI <- new_class(
9697
parent = Provider,
9798
properties = list(
9899
api_key = prop_string(),
99-
model = prop_string(),
100100
# no longer used by OpenAI itself; but subclasses still need it
101101
seed = prop_number_whole(allow_null = TRUE)
102102
)
@@ -241,10 +241,7 @@ method(value_turn, ProviderOpenAI) <- function(
241241
result$usage$prompt_tokens %||% NA_integer_,
242242
result$usage$completion_tokens %||% NA_integer_
243243
)
244-
tokens_log(
245-
paste0("OpenAI-", gsub("https?://", "", provider@base_url)),
246-
tokens
247-
)
244+
tokens_log(provider, tokens)
248245

249246
Turn(message$role, content, json = result, tokens = tokens)
250247
}

R/provider-openrouter.R

+1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ chat_openrouter <- function(
3737
}
3838

3939
provider <- ProviderOpenRouter(
40+
name = "OpenRouter",
4041
base_url = "https://openrouter.ai/api/v1",
4142
model = model,
4243
seed = seed,

R/provider-snowflake.R

+2-1
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ chat_snowflake <- function(
5858
credentials <- credentials %||% default_snowflake_credentials(account)
5959

6060
provider <- ProviderSnowflake(
61+
name = "Snowflake/Cortex",
6162
base_url = snowflake_url(account),
6263
account = account,
6364
credentials = credentials,
@@ -145,7 +146,7 @@ method(value_turn, ProviderSnowflake) <- function(
145146
result$usage$prompt_tokens %||% NA_integer_,
146147
result$usage$completion_tokens %||% NA_integer_
147148
)
148-
tokens_log(paste0("Snowflake-", provider@account), tokens)
149+
tokens_log(provider, tokens)
149150
Turn(
150151
# Snowflake's response format seems to omit the role.
151152
"assistant",

R/provider-vllm.R

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ chat_vllm <- function(
4141
echo <- check_echo(echo)
4242

4343
provider <- ProviderVllm(
44+
name = "VLLM",
4445
base_url = base_url,
4546
model = model,
4647
seed = seed,

R/provider.R

+13-1
Original file line numberDiff line numberDiff line change
@@ -13,21 +13,33 @@ NULL
1313
#' the various generics that control the behavior of each provider.
1414
#'
1515
#' @export
16+
#' @param name Name of the provider.
17+
#' @param model Name of the model.
1618
#' @param base_url The base URL for the API.
1719
#' @param params A list of standard parameters created by [params()].
1820
#' @param extra_args Arbitrary extra arguments to be included in the request body.
1921
#' @return An S7 Provider object.
2022
#' @examples
21-
#' Provider(base_url = "https://cool-models.com")
23+
#' Provider(
24+
#' name = "CoolModels",
25+
#' model = "my_model",
26+
#' base_url = "https://cool-models.com"
27+
#' )
2228
Provider <- new_class(
2329
"Provider",
2430
properties = list(
31+
name = prop_string(),
32+
model = prop_string(),
2533
base_url = prop_string(),
2634
params = class_list,
2735
extra_args = class_list
2836
)
2937
)
3038

39+
test_provider <- function(name = "", model = "", base_url = "", ...) {
40+
Provider(name = name, model = model, base_url = base_url, ...)
41+
}
42+
3143
# Create a request------------------------------------
3244

3345
chat_request <- new_generic(

R/tokens.R

+14-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1-
tokens_log <- function(name, tokens) {
1+
tokens_log <- function(provider, tokens) {
2+
# TODO: probably should make this store in a data frame, but will tackle
3+
# when implementing token costs.
4+
5+
name <- paste0(provider@name, "/", provider@model)
6+
27
if (is.null(the$tokens)) {
38
the$tokens <- list()
49
}
@@ -11,7 +16,14 @@ tokens_log <- function(name, tokens) {
1116
invisible()
1217
}
1318

14-
tokens_reset <- function() {
19+
local_tokens <- function(frame = parent.frame()) {
20+
old <- the$tokens
21+
the$tokens <- NULL
22+
23+
defer(the$tokens <- old, env = frame)
24+
}
25+
26+
tokens_set <- function() {
1527
the$tokens <- NULL
1628
invisible()
1729
}

man/Provider.Rd

+16-2
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

tests/testthat/_snaps/chat.md

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
Code
44
chat
55
Output
6-
<Chat turns=3 tokens=15/5>
6+
<Chat OpenAI/gpt-4o turns=3 tokens=15/5>
77
-- system [0] ------------------------------------------------------------------
88
You're a helpful assistant that returns very minimal output
99
-- user [15] -------------------------------------------------------------------

tests/testthat/_snaps/tokens.md

+8
Original file line numberDiff line numberDiff line change
@@ -5,3 +5,11 @@
55
Message
66
x No recorded usage in this session
77

8+
# can retrieve and log tokens
9+
10+
Code
11+
token_usage()
12+
Output
13+
name input output
14+
1 testprovider/test 10 60
15+

0 commit comments

Comments
 (0)