Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# ellmer (development version)

* New `Chat$get_cost_details()` to get turn-by-turn token usage + costs (#812).
* Updated pricing data (#790).
* The following deprecated functions/arguments/methods have now been removed:
* `Chat$extract_data()` -> `chat$chat_structured()` (0.2.0)
Expand Down
65 changes: 31 additions & 34 deletions R/chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,20 @@ Chat <- R6::R6Class(

#' @description Add a pair of turns to the chat.
#' @param user The user [Turn].
#' @param system The system [Turn].
add_turn = function(user, system) {
#' @param assistant The assistant [Turn].
add_turn = function(user, assistant) {
check_turn(user)
check_turn(system)
check_turn(assistant)

tokens_log(
private$provider,
# TODO: store better representation in Turn object
exec(tokens, !!!as.list(assistant@tokens)),
assistant@cost
)

private$.turns[[length(private$.turns) + 1]] <- user
private$.turns[[length(private$.turns) + 1]] <- system
private$.turns[[length(private$.turns) + 1]] <- assistant
invisible(self)
},

Expand Down Expand Up @@ -120,15 +127,11 @@ Chat <- R6::R6Class(
#' @param include_system_prompt Whether to include the system prompt in
#' the turns (if any exists).
get_tokens = function(include_system_prompt = FALSE) {
turns <- self$get_turns(include_system_prompt = FALSE)
turns <- self$get_turns()
assistant_turns <- keep(turns, function(x) x@role == "assistant")

n <- length(assistant_turns)
tokens_acc <- t(vapply(
assistant_turns,
function(turn) turn@tokens,
double(3)
))
tokens_acc <- map_tokens(assistant_turns, \(turn) turn@tokens)
# Combine counts for input tokens (cached and uncached)
tokens_acc[, 1] <- tokens_acc[, 1] + tokens_acc[, 3]
# Then drop cached tokens counts
Expand Down Expand Up @@ -177,24 +180,29 @@ Chat <- R6::R6Class(
get_cost = function(include = c("all", "last")) {
include <- arg_match(include)

turns <- self$get_turns(include_system_prompt = FALSE)
turns <- self$get_turns()
assistant_turns <- keep(turns, function(x) x@role == "assistant")
n <- length(assistant_turns)
tokens <- t(vapply(
assistant_turns,
function(turn) turn@tokens,
double(3)
))

if (length(assistant_turns) == 0) {
return(dollars(0))
}

if (include == "last") {
tokens <- tokens[nrow(tokens), , drop = FALSE]
cost <- assistant_turns[[length(assistant_turns)]]@cost
} else {
cost <- sum(map_dbl(assistant_turns, \(turn) turn@cost))
}

private$compute_cost(
input = sum(tokens[, 1]),
output = sum(tokens[, 2]),
cached_input = sum(tokens[, 3])
)
dollars(cost)
},

#' @description The token usage and cost of each assistant turn in this
#' chat. Returns a data frame with one row per assistant turn and columns
#' `input`, `output`, `cached_input`, and `cost`.
get_cost_details = function() {
turns <- self$get_turns(include_system_prompt = FALSE)
assistant_turns <- keep(turns, function(x) x@role == "assistant")
# One row per assistant turn; column names come from map_tokens().
tokens <- as.data.frame(map_tokens(assistant_turns, \(turn) turn@tokens))
# Per-turn cost is read from the value stored on each turn.
tokens$cost <- dollars(map_dbl(assistant_turns, \(turn) turn@cost))
tokens
},

#' @description The last turn returned by the assistant.
Expand Down Expand Up @@ -764,17 +772,6 @@ Chat <- R6::R6Class(

has_system_prompt = function() {
length(private$.turns) > 0 && private$.turns[[1]]@role == "system"
},

compute_cost = function(input, output, cached_input) {
get_token_cost(
private$provider@name,
private$provider@model,
variant = "",
input = input,
output = output,
cached_input = cached_input
)
}
)
)
Expand Down
36 changes: 13 additions & 23 deletions R/parallel-chat.R
Original file line number Diff line number Diff line change
Expand Up @@ -270,29 +270,19 @@ multi_convert <- function(
out$.error <- errors
}

if (include_tokens || include_cost) {
tokens <- t(vapply(
turns,
\(turn) if (turn_failed(turn)) c(0L, 0L, 0L) else turn@tokens,
integer(3)
))

if (include_tokens) {
out$input_tokens <- tokens[, 1]
out$output_tokens <- tokens[, 2]
out$cached_input_tokens <- tokens[, 3]
}

if (include_cost) {
out$cost <- get_token_cost(
provider@name,
provider@model,
variant = "",
input = tokens[, 1],
output = tokens[, 2],
cached_input = tokens[, 3]
)
}
if (include_tokens) {
tokens <- map_tokens(turns, \(turn) {
if (turn_failed(turn)) c(0L, 0L, 0L) else turn@tokens
})
out$input_tokens <- tokens[, 1]
out$output_tokens <- tokens[, 2]
out$cached_input_tokens <- tokens[, 3]
}

if (include_cost) {
out$cost <- map_dbl(turns, \(turn) {
if (turn_failed(turn)) 0 else turn@cost
})
}
}
out
Expand Down
4 changes: 2 additions & 2 deletions R/provider-anthropic.R
Original file line number Diff line number Diff line change
Expand Up @@ -293,8 +293,8 @@ method(value_turn, ProviderAnthropic) <- function(
})

tokens <- value_tokens(provider, result)
tokens_log(provider, tokens)
assistant_turn(contents, json = result, tokens = unlist(tokens))
cost <- get_token_cost(provider, tokens)
assistant_turn(contents, json = result, tokens = unlist(tokens), cost = cost)
}

# ellmer -> Claude --------------------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion R/provider-aws.R
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,8 @@ method(value_turn, ProviderAWSBedrock) <- function(
})

tokens <- value_tokens(provider, result)
assistant_turn(contents, json = result, tokens = unlist(tokens))
cost <- get_token_cost(provider, tokens)
assistant_turn(contents, json = result, tokens = unlist(tokens), cost = cost)
}

# ellmer -> Bedrock -------------------------------------------------------------
Expand Down
3 changes: 2 additions & 1 deletion R/provider-google.R
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,8 @@ method(value_turn, ProviderGoogleGemini) <- function(
})
contents <- compact(contents)
tokens <- value_tokens(provider, result)
assistant_turn(contents, json = result, tokens = unlist(tokens))
cost <- get_token_cost(provider, tokens)
assistant_turn(contents, json = result, tokens = unlist(tokens), cost = cost)
}

# ellmer -> Gemini --------------------------------------------------------------
Expand Down
9 changes: 7 additions & 2 deletions R/provider-openai-responses.R
Original file line number Diff line number Diff line change
Expand Up @@ -225,8 +225,13 @@ method(value_turn, ProviderOpenAIResponses) <- function(
})

tokens <- value_tokens(provider, result)
tokens_log(provider, tokens)
assistant_turn(contents = contents, json = result, tokens = unlist(tokens))
cost <- get_token_cost(provider, tokens)
assistant_turn(
contents = contents,
json = result,
tokens = unlist(tokens),
cost = cost
)
}

# ellmer -> OpenAI --------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions R/provider-openai.R
Original file line number Diff line number Diff line change
Expand Up @@ -266,8 +266,8 @@ method(value_turn, ProviderOpenAI) <- function(
}

tokens <- value_tokens(provider, result)
tokens_log(provider, tokens)
assistant_turn(content, json = result, tokens = unlist(tokens))
cost <- get_token_cost(provider, tokens)
assistant_turn(content, json = result, tokens = unlist(tokens), cost = cost)
}

# ellmer -> OpenAI --------------------------------------------------------------
Expand Down
4 changes: 2 additions & 2 deletions R/provider-snowflake.R
Original file line number Diff line number Diff line change
Expand Up @@ -260,8 +260,8 @@ method(value_turn, ProviderSnowflakeCortex) <- function(
}
})
tokens <- value_tokens(provider, result)
tokens_log(provider, tokens)
assistant_turn(contents, json = result, tokens = unlist(tokens))
cost <- get_token_cost(provider, tokens)
assistant_turn(contents, json = result, tokens = unlist(tokens), cost = cost)
}

# ellmer -> Snowflake --------------------------------------------------------
Expand Down
58 changes: 26 additions & 32 deletions R/tokens.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,31 +14,37 @@ tokens <- function(input = 0, output = 0, cached_input = 0) {
)
}

tokens_log <- function(provider, tokens, variant = "") {
# Apply `f` to each element of `x` and stack the results into a matrix with
# one row per element and three named columns.
#
# @param x A list (or vector) to iterate over, e.g. a list of turns.
# @param f A function returning a length-3 numeric vector ordered as
#   input, output, cached_input.
# @param ... Additional arguments forwarded to `f` (previously accepted but
#   silently dropped).
# @return A double matrix with `length(x)` rows and columns `input`,
#   `output`, and `cached_input`.
map_tokens <- function(x, f, ...) {
  # vapply() yields a 3 x length(x) matrix (integer results are promoted to
  # double), so transpose to get one row per element.
  out <- t(vapply(x, f, double(3), ...))
  colnames(out) <- c("input", "output", "cached_input")
  out
}

tokens_log <- function(provider, tokens, cost) {
i <- vctrs::vec_match(
data.frame(
provider = provider@name,
model = provider@model,
variant = variant
model = provider@model
),
the$tokens[c("provider", "model", "variant")]
the$tokens[c("provider", "model")]
)

if (is.na(i)) {
new_row <- tokens_row(
provider@name,
provider@model,
variant,
tokens$input,
tokens$output,
tokens$cached_input
tokens$cached_input,
cost
)
the$tokens <- rbind(the$tokens, new_row)
} else {
the$tokens$input[i] <- the$tokens$input[i] + tokens$input
the$tokens$output[i] <- the$tokens$output[i] + tokens$output
the$tokens$cached_input[i] <- the$tokens$cached_input[i] +
tokens$cached_input
the$tokens$price[i] <- the$tokens$price[i] + cost
}

invisible()
Expand All @@ -47,18 +53,18 @@ tokens_log <- function(provider, tokens, variant = "") {
tokens_row <- function(
provider = character(0),
model = character(0),
variant = character(0),
input = numeric(0),
output = numeric(0),
cached_input = numeric(0)
cached_input = numeric(0),
price = numeric(0)
) {
data.frame(
provider = provider,
model = model,
variant = variant,
input = input,
output = output,
cached_input = cached_input
cached_input = cached_input,
price = price
)
}

Expand All @@ -85,16 +91,7 @@ token_usage <- function() {
return(invisible(the$tokens))
}

out <- the$tokens
out$price <- get_token_cost(
out$provider,
out$model,
out$variant,
out$input,
out$output,
out$cached_input
)
out
the$tokens
}

# Cost ----------------------------------------------------------------------
Expand All @@ -104,15 +101,12 @@ has_cost <- function(provider, model) {
vctrs::vec_in(needle, prices[c("provider", "model")])
}

get_token_cost <- function(
provider,
model,
variant,
input = 0,
output = 0,
cached_input = 0
) {
needle <- data.frame(provider = provider, model = model, variant = variant)
get_token_cost <- function(provider, tokens, variant = "") {
needle <- data.frame(
provider = provider@name,
model = provider@model,
variant = variant
)
idx <- vctrs::vec_match(needle, prices[c("provider", "model", "variant")])

if (any(is.na(idx))) {
Expand All @@ -126,9 +120,9 @@ get_token_cost <- function(
)
}

input_price <- input * prices$input[idx] / 1e6
output_price <- output * prices$output[idx] / 1e6
cached_input_price <- cached_input * prices$cached_input[idx] / 1e6
input_price <- tokens$input * prices$input[idx] / 1e6
output_price <- tokens$output * prices$output[idx] / 1e6
cached_input_price <- tokens$cached_input * prices$cached_input[idx] / 1e6

dollars(input_price + output_price + cached_input_price)
}
Expand Down
17 changes: 13 additions & 4 deletions R/turns.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,12 @@ NULL
#' This is useful if there's information returned by the provider that ellmer
#' doesn't otherwise expose.
#' @param tokens A numeric vector of length 2 representing the number of
#' input and output tokens (respectively) used in this turn. Currently
#' only recorded for assistant turns.
#' @param duration The duration of the request in seconds. `NA` for user turns,
#' input and output tokens (respectively) used in this turn.
#' Only meaningful for assistant turns.
#' @param cost The cost of the turn in dollars. Only meaningful for assistant
#' turns.
#' @param duration The duration of the request in seconds.
#' Only meaningful for assistant turns (numeric); `NA` otherwise.
#' @export
#' @return An S7 `Turn` object
Expand All @@ -47,6 +50,10 @@ Turn <- new_class(
}
}
),
cost = new_property(
class = class_numeric,
default = NA_real_
),
duration = new_property(
class_numeric,
default = NA_real_
Expand All @@ -61,6 +68,7 @@ Turn <- new_class(
contents = list(),
json = list(),
tokens = c(0, 0, 0),
cost = NA_real_,
duration = NA_real_
) {
if (is.character(contents)) {
Expand All @@ -72,7 +80,8 @@ Turn <- new_class(
contents = contents,
json = json,
tokens = tokens,
duration = duration
duration = duration,
cost = cost
)
}
)
Expand Down
Loading