diff --git a/NEWS.md b/NEWS.md index 0f6613ab..3deac911 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # ellmer (development version) +* New `Chat$get_cost_details()` to get turn-by-turn token usage and costs (#812). * Updated pricing data (#790). * The following deprecated functions/arguments/methods have now been removed: * `Chat$extract_data()` -> `chat$chat_structured()` (0.2.0) diff --git a/R/chat.R b/R/chat.R index 83644c86..e659f6df 100644 --- a/R/chat.R +++ b/R/chat.R @@ -70,13 +70,20 @@ Chat <- R6::R6Class( #' @description Add a pair of turns to the chat. #' @param user The user [Turn]. - #' @param system The system [Turn]. - add_turn = function(user, system) { + #' @param assistant The assistant [Turn]. + add_turn = function(user, assistant) { check_turn(user) - check_turn(system) + check_turn(assistant) + + tokens_log( + private$provider, + # TODO: store better representation in Turn object + exec(tokens, !!!as.list(assistant@tokens)), + assistant@cost + ) private$.turns[[length(private$.turns) + 1]] <- user - private$.turns[[length(private$.turns) + 1]] <- system + private$.turns[[length(private$.turns) + 1]] <- assistant invisible(self) }, @@ -120,15 +127,11 @@ Chat <- R6::R6Class( #' @param include_system_prompt Whether to include the system prompt in #' the turns (if any exists). 
get_tokens = function(include_system_prompt = FALSE) { - turns <- self$get_turns(include_system_prompt = FALSE) + turns <- self$get_turns() assistant_turns <- keep(turns, function(x) x@role == "assistant") n <- length(assistant_turns) - tokens_acc <- t(vapply( - assistant_turns, - function(turn) turn@tokens, - double(3) - )) + tokens_acc <- map_tokens(assistant_turns, \(turn) turn@tokens) # Combine counts for input tokens (cached and uncached) tokens_acc[, 1] <- tokens_acc[, 1] + tokens_acc[, 3] # Then drop cached tokens counts @@ -177,24 +180,29 @@ Chat <- R6::R6Class( get_cost = function(include = c("all", "last")) { include <- arg_match(include) - turns <- self$get_turns(include_system_prompt = FALSE) + turns <- self$get_turns() assistant_turns <- keep(turns, function(x) x@role == "assistant") - n <- length(assistant_turns) - tokens <- t(vapply( - assistant_turns, - function(turn) turn@tokens, - double(3) - )) + + if (length(assistant_turns) == 0) { + return(dollars(0)) + } if (include == "last") { - tokens <- tokens[nrow(tokens), , drop = FALSE] + cost <- assistant_turns[[length(assistant_turns)]]@cost + } else { + cost <- sum(map_dbl(assistant_turns, \(turn) turn@cost)) } - private$compute_cost( - input = sum(tokens[, 1]), - output = sum(tokens[, 2]), - cached_input = sum(tokens[, 3]) - ) + dollars(cost) + }, + + #' @description The tokens and cost for each assistant turn of this chat. + get_cost_details = function() { + turns <- self$get_turns(include_system_prompt = FALSE) + assistant_turns <- keep(turns, function(x) x@role == "assistant") + tokens <- as.data.frame(map_tokens(assistant_turns, \(turn) turn@tokens)) + tokens$cost <- dollars(map_dbl(assistant_turns, \(turn) turn@cost)) + tokens }, #' @description The last turn returned by the assistant. 
@@ -764,17 +772,6 @@ Chat <- R6::R6Class( has_system_prompt = function() { length(private$.turns) > 0 && private$.turns[[1]]@role == "system" - }, - - compute_cost = function(input, output, cached_input) { - get_token_cost( - private$provider@name, - private$provider@model, - variant = "", - input = input, - output = output, - cached_input = cached_input - ) } ) ) diff --git a/R/parallel-chat.R b/R/parallel-chat.R index 7a536184..1650cd76 100644 --- a/R/parallel-chat.R +++ b/R/parallel-chat.R @@ -270,29 +270,19 @@ multi_convert <- function( out$.error <- errors } - if (include_tokens || include_cost) { - tokens <- t(vapply( - turns, - \(turn) if (turn_failed(turn)) c(0L, 0L, 0L) else turn@tokens, - integer(3) - )) - - if (include_tokens) { - out$input_tokens <- tokens[, 1] - out$output_tokens <- tokens[, 2] - out$cached_input_tokens <- tokens[, 3] - } - - if (include_cost) { - out$cost <- get_token_cost( - provider@name, - provider@model, - variant = "", - input = tokens[, 1], - output = tokens[, 2], - cached_input = tokens[, 3] - ) - } + if (include_tokens) { + tokens <- map_tokens(turns, \(turn) { + if (turn_failed(turn)) c(0L, 0L, 0L) else turn@tokens + }) + out$input_tokens <- tokens[, 1] + out$output_tokens <- tokens[, 2] + out$cached_input_tokens <- tokens[, 3] + } + + if (include_cost) { + out$cost <- map_dbl(turns, \(turn) { + if (turn_failed(turn)) 0 else turn@cost + }) } } out diff --git a/R/provider-anthropic.R b/R/provider-anthropic.R index d1f4fae3..9c30ba9e 100644 --- a/R/provider-anthropic.R +++ b/R/provider-anthropic.R @@ -293,8 +293,8 @@ method(value_turn, ProviderAnthropic) <- function( }) tokens <- value_tokens(provider, result) - tokens_log(provider, tokens) - assistant_turn(contents, json = result, tokens = unlist(tokens)) + cost <- get_token_cost(provider, tokens) + assistant_turn(contents, json = result, tokens = unlist(tokens), cost = cost) } # ellmer -> Claude -------------------------------------------------------------- diff --git 
a/R/provider-aws.R b/R/provider-aws.R index fc66c972..becc61a9 100644 --- a/R/provider-aws.R +++ b/R/provider-aws.R @@ -358,7 +358,8 @@ method(value_turn, ProviderAWSBedrock) <- function( }) tokens <- value_tokens(provider, result) - assistant_turn(contents, json = result, tokens = unlist(tokens)) + cost <- get_token_cost(provider, tokens) + assistant_turn(contents, json = result, tokens = unlist(tokens), cost = cost) } # ellmer -> Bedrock ------------------------------------------------------------- diff --git a/R/provider-google.R b/R/provider-google.R index 8d319708..20f6f2de 100644 --- a/R/provider-google.R +++ b/R/provider-google.R @@ -307,7 +307,8 @@ method(value_turn, ProviderGoogleGemini) <- function( }) contents <- compact(contents) tokens <- value_tokens(provider, result) - assistant_turn(contents, json = result, tokens = unlist(tokens)) + cost <- get_token_cost(provider, tokens) + assistant_turn(contents, json = result, tokens = unlist(tokens), cost = cost) } # ellmer -> Gemini -------------------------------------------------------------- diff --git a/R/provider-openai-responses.R b/R/provider-openai-responses.R index 67ff107d..a88e4296 100644 --- a/R/provider-openai-responses.R +++ b/R/provider-openai-responses.R @@ -225,8 +225,13 @@ method(value_turn, ProviderOpenAIResponses) <- function( }) tokens <- value_tokens(provider, result) - tokens_log(provider, tokens) - assistant_turn(contents = contents, json = result, tokens = unlist(tokens)) + cost <- get_token_cost(provider, tokens) + assistant_turn( + contents = contents, + json = result, + tokens = unlist(tokens), + cost = cost + ) } # ellmer -> OpenAI -------------------------------------------------------------- diff --git a/R/provider-openai.R b/R/provider-openai.R index 1eb93766..5df1a4cf 100644 --- a/R/provider-openai.R +++ b/R/provider-openai.R @@ -266,8 +266,8 @@ method(value_turn, ProviderOpenAI) <- function( } tokens <- value_tokens(provider, result) - tokens_log(provider, tokens) - 
assistant_turn(content, json = result, tokens = unlist(tokens)) + cost <- get_token_cost(provider, tokens) + assistant_turn(content, json = result, tokens = unlist(tokens), cost = cost) } # ellmer -> OpenAI -------------------------------------------------------------- diff --git a/R/provider-snowflake.R b/R/provider-snowflake.R index b2524912..d43a9903 100644 --- a/R/provider-snowflake.R +++ b/R/provider-snowflake.R @@ -260,8 +260,8 @@ method(value_turn, ProviderSnowflakeCortex) <- function( } }) tokens <- value_tokens(provider, result) - tokens_log(provider, tokens) - assistant_turn(contents, json = result, tokens = unlist(tokens)) + cost <- get_token_cost(provider, tokens) + assistant_turn(contents, json = result, tokens = unlist(tokens), cost = cost) } # ellmer -> Snowflake -------------------------------------------------------- diff --git a/R/tokens.R b/R/tokens.R index ed967416..155aaff4 100644 --- a/R/tokens.R +++ b/R/tokens.R @@ -14,24 +14,29 @@ tokens <- function(input = 0, output = 0, cached_input = 0) { ) } -tokens_log <- function(provider, tokens, variant = "") { +map_tokens <- function(x, f, ...) 
{ + out <- t(vapply(x, f, double(3))) + colnames(out) <- c("input", "output", "cached_input") + out +} + +tokens_log <- function(provider, tokens, cost) { i <- vctrs::vec_match( data.frame( provider = provider@name, - model = provider@model, - variant = variant + model = provider@model ), - the$tokens[c("provider", "model", "variant")] + the$tokens[c("provider", "model")] ) if (is.na(i)) { new_row <- tokens_row( provider@name, provider@model, - variant, tokens$input, tokens$output, - tokens$cached_input + tokens$cached_input, + cost ) the$tokens <- rbind(the$tokens, new_row) } else { @@ -39,6 +44,7 @@ tokens_log <- function(provider, tokens, variant = "") { the$tokens$output[i] <- the$tokens$output[i] + tokens$output the$tokens$cached_input[i] <- the$tokens$cached_input[i] + tokens$cached_input + the$tokens$price[i] <- the$tokens$price[i] + cost } invisible() @@ -47,18 +53,18 @@ tokens_log <- function(provider, tokens, variant = "") { tokens_row <- function( provider = character(0), model = character(0), - variant = character(0), input = numeric(0), output = numeric(0), - cached_input = numeric(0) + cached_input = numeric(0), + price = numeric(0) ) { data.frame( provider = provider, model = model, - variant = variant, input = input, output = output, - cached_input = cached_input + cached_input = cached_input, + price = price ) } @@ -85,16 +91,7 @@ token_usage <- function() { return(invisible(the$tokens)) } - out <- the$tokens - out$price <- get_token_cost( - out$provider, - out$model, - out$variant, - out$input, - out$output, - out$cached_input - ) - out + the$tokens } # Cost ---------------------------------------------------------------------- @@ -104,15 +101,12 @@ has_cost <- function(provider, model) { vctrs::vec_in(needle, prices[c("provider", "model")]) } -get_token_cost <- function( - provider, - model, - variant, - input = 0, - output = 0, - cached_input = 0 -) { - needle <- data.frame(provider = provider, model = model, variant = variant) +get_token_cost 
<- function(provider, tokens, variant = "") { + needle <- data.frame( + provider = provider@name, + model = provider@model, + variant = variant + ) idx <- vctrs::vec_match(needle, prices[c("provider", "model", "variant")]) if (any(is.na(idx))) { @@ -126,9 +120,9 @@ get_token_cost <- function( ) } - input_price <- input * prices$input[idx] / 1e6 - output_price <- output * prices$output[idx] / 1e6 - cached_input_price <- cached_input * prices$cached_input[idx] / 1e6 + input_price <- tokens$input * prices$input[idx] / 1e6 + output_price <- tokens$output * prices$output[idx] / 1e6 + cached_input_price <- tokens$cached_input * prices$cached_input[idx] / 1e6 dollars(input_price + output_price + cached_input_price) } diff --git a/R/turns.R b/R/turns.R index 3b80ba7f..5a23c9e7 100644 --- a/R/turns.R +++ b/R/turns.R @@ -24,9 +24,12 @@ NULL #' This is useful if there's information returned by the provider that ellmer #' doesn't otherwise expose. #' @param tokens A numeric vector of length 2 representing the number of -#' input and output tokens (respectively) used in this turn. Currently -#' only recorded for assistant turns. -#' @param duration The duration of the request in seconds. `NA` for user turns, +#' input and output tokens (respectively) used in this turn. +#' Only meaningful for assistant turns. +#' @param cost The cost of the turn in dollars. Only meaningful for assistant +#' turns. +#' @param duration The duration of the request in seconds. +#' Only meaningful for assistant turns. #' numeric for assistant turns. 
#' @export #' @return An S7 `Turn` object @@ -47,6 +50,10 @@ Turn <- new_class( } } ), + cost = new_property( + class = class_numeric, + default = NA_real_ + ), duration = new_property( class_numeric, default = NA_real_ @@ -61,6 +68,7 @@ Turn <- new_class( contents = list(), json = list(), tokens = c(0, 0, 0), + cost = NA_real_, duration = NA_real_ ) { if (is.character(contents)) { @@ -72,7 +80,8 @@ Turn <- new_class( contents = contents, json = json, tokens = tokens, - duration = duration + duration = duration, + cost = cost ) } ) diff --git a/man/Chat.Rd b/man/Chat.Rd index 512436c9..a0eaacad 100644 --- a/man/Chat.Rd +++ b/man/Chat.Rd @@ -35,6 +35,7 @@ chat$chat("Tell me a funny joke") \item \href{#method-Chat-set_system_prompt}{\code{Chat$set_system_prompt()}} \item \href{#method-Chat-get_tokens}{\code{Chat$get_tokens()}} \item \href{#method-Chat-get_cost}{\code{Chat$get_cost()}} +\item \href{#method-Chat-get_cost_details}{\code{Chat$get_cost_details()}} \item \href{#method-Chat-last_turn}{\code{Chat$last_turn()}} \item \href{#method-Chat-chat}{\code{Chat$chat()}} \item \href{#method-Chat-chat_structured}{\code{Chat$chat_structured()}} @@ -123,7 +124,7 @@ Replace existing turns with a new list. \subsection{Method \code{add_turn()}}{ Add a pair of turns to the chat. \subsection{Usage}{ -\if{html}{\out{