Skip to content

Commit 605ef04

Browse files
authored
Compute token costs (#418)
Fixes #203
1 parent 159aa2b commit 605ef04

25 files changed

+387
-65
lines changed

Diff for: .Rbuildignore

+1
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,4 @@ _cache/
1313
^CRAN-SUBMISSION$
1414
^[\.]?air\.toml$
1515
^\.vscode$
16+
^data-raw$

Diff for: DESCRIPTION

+2
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ Description: Chat with large language models from a range of providers
1616
License: MIT + file LICENSE
1717
URL: https://ellmer.tidyverse.org, https://github.com/tidyverse/ellmer
1818
BugReports: https://github.com/tidyverse/ellmer/issues
19+
Depends:
20+
R (>= 4.1)
1921
Imports:
2022
cli,
2123
coro (>= 1.1.0),

Diff for: NAMESPACE

+2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
# Generated by roxygen2: do not edit by hand
22

3+
S3method(format,ellmer_dollars)
34
S3method(print,Chat)
5+
S3method(print,ellmer_dollars)
46
export(Content)
57
export(ContentImage)
68
export(ContentImageInline)

Diff for: NEWS.md

+7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
# ellmer (development version)
22

3+
* ellmer now tracks the cost of input and output tokens. The cost is displayed
4+
when you print a `Chat` object, in `token_usage()`, and with
5+
`Chat$get_cost()`. This is our best effort at computing the cost, but you
6+
should treat it as an estimate rather than the exact price. Unfortunately, LLM APIs
7+
currently make it very hard to figure out exactly how much your queries are
8+
costing (#203).
9+
310
* `ContentToolResult` objects now include the error condition in the `error`
411
property when a tool call fails (#421, @gadenbuie).
512

Diff for: R/chat.R

+49-6
Original file line numberDiff line numberDiff line change
@@ -119,32 +119,67 @@ Chat <- R6::R6Class(
119119
assistant_turns <- keep(turns, function(x) x@role == "assistant")
120120

121121
n <- length(assistant_turns)
122-
tokens <- t(vapply(
122+
tokens_acc <- t(vapply(
123123
assistant_turns,
124124
function(turn) turn@tokens,
125125
double(2)
126126
))
127+
128+
tokens <- tokens_acc
127129
if (n > 1) {
128130
# Compute just the new tokens
129131
tokens[-1, 1] <- tokens[seq(2, n), 1] -
130132
(tokens[seq(1, n - 1), 1] + tokens[seq(1, n - 1), 2])
131133
}
132134
# collapse into a single vector
133135
tokens_v <- c(t(tokens))
136+
tokens_acc_v <- c(t(tokens_acc))
134137

135138
tokens_df <- data.frame(
136139
role = rep(c("user", "assistant"), times = n),
137-
tokens = tokens_v
140+
tokens = tokens_v,
141+
tokens_total = tokens_acc_v
138142
)
139143

140144
if (include_system_prompt && private$has_system_prompt()) {
141145
# How do we compute this?
142-
tokens_df <- rbind(data.frame(role = "system", tokens = 0), tokens_df)
146+
tokens_df <- rbind(
147+
data.frame(role = "system", tokens = 0, tokens_total = 0),
148+
tokens_df
149+
)
143150
}
144151

145152
tokens_df
146153
},
147154

155+
#' @description The cost of this chat
156+
#' @param include The default, `"all"`, gives the total cumulative cost
157+
#' of this chat. Alternatively, use `"last"` to get the cost of just the
158+
#' most recent turn.
159+
get_cost = function(include = c("all", "last")) {
160+
include <- arg_match(include)
161+
162+
turns <- self$get_turns(include_system_prompt = FALSE)
163+
assistant_turns <- keep(turns, function(x) x@role == "assistant")
164+
n <- length(assistant_turns)
165+
tokens <- t(vapply(
166+
assistant_turns,
167+
function(turn) turn@tokens,
168+
double(2)
169+
))
170+
171+
if (include == "last") {
172+
tokens <- tokens[nrow(tokens), , drop = FALSE]
173+
}
174+
175+
get_token_cost(
176+
private$provider@name,
177+
private$provider@model,
178+
input = sum(tokens[, 1]),
179+
output = sum(tokens[, 2])
180+
)
181+
},
182+
148183
#' @description The last turn returned by the assistant.
149184
#' @param role Optionally, specify a role to find the last turn with
150185
#' for the role.
@@ -645,14 +680,22 @@ print.Chat <- function(x, ...) {
645680
turns <- x$get_turns(include_system_prompt = TRUE)
646681

647682
tokens <- x$tokens(include_system_prompt = TRUE)
648-
tokens_user <- sum(tokens$tokens[tokens$role == "user"])
649-
tokens_assistant <- sum(tokens$tokens[tokens$role == "assistant"])
683+
684+
tokens_user <- sum(tokens$tokens_total[tokens$role == "user"])
685+
tokens_assistant <- sum(tokens$tokens_total[tokens$role == "assistant"])
686+
cost <- x$get_cost()
650687

651688
cat(paste_c(
652689
"<Chat",
653690
c(" ", provider@name, "/", provider@model),
654691
c(" turns=", length(turns)),
655-
c(" tokens=", tokens_user, "/", tokens_assistant),
692+
c(
693+
" tokens=",
694+
tokens_user,
695+
"/",
696+
tokens_assistant
697+
),
698+
if (!is.na(cost)) c(" ", format(cost)),
656699
">\n"
657700
))
658701

Diff for: R/provider-bedrock.R

+5-2
Original file line numberDiff line numberDiff line change
@@ -263,8 +263,11 @@ method(value_turn, ProviderAWSBedrock) <- function(
263263
}
264264
})
265265

266-
tokens <- c(result$usage$inputTokens, result$usage$outputTokens)
267-
tokens_log(provider, tokens)
266+
tokens <- tokens_log(
267+
provider,
268+
input = result$usage$inputTokens,
269+
output = result$usage$outputTokens
270+
)
268271

269272
Turn(result$output$message$role, contents, json = result, tokens = tokens)
270273
}

Diff for: R/provider-claude.R

+10-2
Original file line numberDiff line numberDiff line change
@@ -288,8 +288,11 @@ method(value_turn, ProviderAnthropic) <- function(
288288
}
289289
})
290290

291-
tokens <- c(result$usage$input_tokens, result$usage$output_tokens)
292-
tokens_log(provider, tokens)
291+
tokens <- tokens_log(
292+
provider,
293+
input = result$usage$input_tokens,
294+
output = result$usage$output_tokens
295+
)
293296

294297
Turn(result$role, contents, json = result, tokens = tokens)
295298
}
@@ -401,6 +404,11 @@ method(as_json, list(ProviderAnthropic, ContentThinking)) <- function(
401404
signature = x@extra$signature
402405
)
403406
}
407+
# Pricing ----------------------------------------------------------------------
408+
409+
method(standardise_model, ProviderAnthropic) <- function(provider, model) {
410+
gsub("-(latest|\\d{8})$", "", model)
411+
}
404412

405413
# Helpers ----------------------------------------------------------------
406414

Diff for: R/provider-gemini.R

+12-4
Original file line numberDiff line numberDiff line change
@@ -203,11 +203,11 @@ method(value_turn, ProviderGoogleGemini) <- function(
203203
})
204204
contents <- compact(contents)
205205
usage <- result$usageMetadata
206-
tokens <- c(
207-
usage$promptTokenCount %||% NA_integer_,
208-
usage$candidatesTokenCount %||% NA_integer_
206+
tokens <- tokens_log(
207+
provider,
208+
input = usage$promptTokenCount,
209+
output = usage$candidatesTokenCount
209210
)
210-
tokens_log(provider, tokens)
211211

212212
Turn("assistant", contents, json = result, tokens = tokens)
213213
}
@@ -562,3 +562,11 @@ default_google_credentials <- function(
562562
list(Authorization = paste("Bearer", token$credentials$access_token))
563563
})
564564
}
565+
566+
# Pricing ----------------------------------------------------------------------
567+
568+
method(standardise_model, ProviderGoogleGemini) <- function(provider, model) {
569+
# https://ai.google.dev/gemini-api/docs/models#model-versions
570+
# <model>-<generation>-<variation>-...
571+
gsub("^([^-]+-[^-]+-[^-]+).*$", "\\1", model)
572+
}

Diff for: R/provider-openai.R

+4-5
Original file line numberDiff line numberDiff line change
@@ -237,12 +237,11 @@ method(value_turn, ProviderOpenAI) <- function(
237237
})
238238
content <- c(content, calls)
239239
}
240-
tokens <- c(
241-
result$usage$prompt_tokens %||% NA_integer_,
242-
result$usage$completion_tokens %||% NA_integer_
240+
tokens <- tokens_log(
241+
provider,
242+
input = result$usage$prompt_tokens,
243+
output = result$usage$completion_tokens
243244
)
244-
tokens_log(provider, tokens)
245-
246245
Turn(message$role, content, json = result, tokens = tokens)
247246
}
248247

Diff for: R/provider-snowflake.R

+4-4
Original file line numberDiff line numberDiff line change
@@ -142,11 +142,11 @@ method(value_turn, ProviderSnowflakeCortex) <- function(
142142
) {
143143
deltas <- compact(sapply(result$choices, function(x) x$delta$content))
144144
content <- list(as_content(paste(deltas, collapse = "")))
145-
tokens <- c(
146-
result$usage$prompt_tokens %||% NA_integer_,
147-
result$usage$completion_tokens %||% NA_integer_
145+
tokens <- tokens_log(
146+
provider,
147+
input = result$usage$prompt_tokens,
148+
output = result$usage$completion_tokens
148149
)
149-
tokens_log(provider, tokens)
150150
Turn(
151151
# Snowflake's response format seems to omit the role.
152152
"assistant",

Diff for: R/provider.R

+14
Original file line numberDiff line numberDiff line change
@@ -113,3 +113,17 @@ method(as_json, list(Provider, class_list)) <- function(provider, x) {
113113
method(as_json, list(Provider, ContentJson)) <- function(provider, x) {
114114
as_json(provider, ContentText("<structured data/>"))
115115
}
116+
117+
# Pricing ---------------------------------------------------------------------
118+
119+
standardise_model <- new_generic(
120+
"standardise_model",
121+
"provider",
122+
function(provider, model) {
123+
S7_dispatch()
124+
}
125+
)
126+
127+
method(standardise_model, Provider) <- function(provider, model) {
128+
model
129+
}

Diff for: R/sysdata.rda

737 Bytes
Binary file not shown.

Diff for: R/tokens.R

+67-27
Original file line numberDiff line numberDiff line change
@@ -1,52 +1,92 @@
1-
tokens_log <- function(provider, tokens) {
2-
# TODO: probably should make this store in a data frame, but will tackle
3-
# when implementing token costs.
1+
on_load(
2+
the$tokens <- tokens_row(character(), character(), numeric(), numeric())
3+
)
44

5-
name <- paste0(provider@name, "/", provider@model)
5+
tokens_log <- function(provider, input = NULL, output = NULL) {
6+
input <- input %||% 0
7+
output <- output %||% 0
68

7-
if (is.null(the$tokens)) {
8-
the$tokens <- list()
9-
}
10-
if (is.null(the$tokens[[name]])) {
11-
the$tokens[[name]] <- c(0, 0)
9+
model <- standardise_model(provider, provider@model)
10+
11+
name <- function(provider, model) paste0(provider, "/", model)
12+
i <- tokens_match(provider@name, model, the$tokens$provider, the$tokens$model)
13+
14+
if (is.na(i)) {
15+
new_row <- tokens_row(provider@name, model, input, output)
16+
the$tokens <- rbind(the$tokens, new_row)
17+
} else {
18+
the$tokens$input[i] <- the$tokens$input[i] + input
19+
the$tokens$output[i] <- the$tokens$output[i] + output
1220
}
1321

14-
tokens[is.na(tokens)] <- 0
15-
the$tokens[[name]] <- the$tokens[[name]] + tokens
16-
invisible()
22+
# Returns value to be passed to Turn
23+
c(input, output)
24+
}
25+
26+
tokens_row <- function(provider, model, input, output) {
27+
data.frame(provider = provider, model = model, input = input, output = output)
1728
}
1829

30+
tokens_match <- function(
31+
provider_needle,
32+
model_needle,
33+
provider_haystack,
34+
model_haystack
35+
) {
36+
match(
37+
paste0(provider_needle, "/", model_needle),
38+
paste0(provider_haystack, "/", model_haystack)
39+
)
40+
}
41+
42+
1943
local_tokens <- function(frame = parent.frame()) {
2044
old <- the$tokens
21-
the$tokens <- NULL
45+
the$tokens <- tokens_row(character(), character(), numeric(), numeric())
2246

2347
defer(the$tokens <- old, env = frame)
2448
}
2549

26-
tokens_set <- function() {
27-
the$tokens <- NULL
28-
invisible()
29-
}
30-
3150
#' Report on token usage in the current session
3251
#'
3352
#' Call this function to find out the cumulative number of tokens that you
34-
#' have sent and recieved in the current session.
53+
#' have sent and received in the current session. The price will be shown
54+
#' if known.
3555
#'
3656
#' @export
3757
#' @return A data frame
3858
#' @examples
3959
#' token_usage()
4060
token_usage <- function() {
41-
if (is.null(the$tokens)) {
61+
if (nrow(the$tokens) == 0) {
4262
cli::cli_inform(c(x = "No recorded usage in this session"))
43-
return(invisible(
44-
data.frame(name = character(), input = numeric(), output = numeric())
45-
))
63+
return(invisible(the$tokens))
4664
}
4765

48-
rows <- map2(names(the$tokens), the$tokens, function(name, tokens) {
49-
data.frame(name = name, input = tokens[[1]], output = tokens[[2]])
50-
})
51-
do.call("rbind", rows)
66+
out <- the$tokens
67+
out$price <- get_token_cost(out$provider, out$model, out$input, out$output)
68+
out
69+
}
70+
71+
# Cost ----------------------------------------------------------------------
72+
73+
get_token_cost <- function(provider, model, input, output) {
74+
idx <- tokens_match(provider, model, prices$provider, prices$model)
75+
76+
input_price <- input * prices$input[idx] / 1e6
77+
output_price <- output * prices$output[idx] / 1e6
78+
dollars(input_price + output_price)
79+
}
80+
81+
dollars <- function(x) {
82+
structure(x, class = c("ellmer_dollars", "numeric"))
83+
}
84+
#' @export
85+
format.ellmer_dollars <- function(x, ...) {
86+
paste0(ifelse(is.na(x), "", "$"), format(unclass(round(x, 2)), nsmall = 2))
87+
}
88+
#' @export
89+
print.ellmer_dollars <- function(x, ...) {
90+
print(format(x), quote = FALSE)
91+
invisible(x)
5292
}

Diff for: R/zzz.R

+3
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33
S7::methods_register()
44
}
55

6+
# Work around S7 bug
7+
rm(format)
8+
69
# enable usage of <S7_object>@name in package code
710
#' @rawNamespace if (getRversion() < "4.3.0") importFrom("S7", "@")
811
NULL

0 commit comments

Comments
 (0)