diff --git a/NEWS.md b/NEWS.md index dc110a12..8fb2ff4b 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,8 @@ # ellmer (development version) +* `chat_gemini()` can now handle responses that include citation metadata + (#358). + * `chat_` functions no longer take a turns object, instead use `set_turns()` (#427). diff --git a/R/provider-gemini.R b/R/provider-gemini.R index 3e2e68cd..f2c7216f 100644 --- a/R/provider-gemini.R +++ b/R/provider-gemini.R @@ -388,6 +388,12 @@ merge_optional <- function(merge_func) { merge_objects <- function(...) { spec <- list(...) function(left, right, path = NULL) { + if (is.null(left)) { + return(right) + } else if (is.null(right)) { + return(left) + } + # cat(paste(collapse = "", path), "\n") stopifnot(is.list(left), is.list(right), all(nzchar(names(spec)))) mapply( @@ -469,6 +475,7 @@ merge_parts <- function() { } # Put it all together... +# https://ai.google.dev/api/generate-content#v1beta.GenerateContentResponse merge_gemini_chunks <- merge_objects( candidates = merge_candidate_lists( content = merge_objects( diff --git a/tests/testthat/test-provider-gemini.R b/tests/testthat/test-provider-gemini.R index 1e238317..4a952935 100644 --- a/tests/testthat/test-provider-gemini.R +++ b/tests/testthat/test-provider-gemini.R @@ -105,3 +105,32 @@ test_that("strips suffix from model name", { "gemini-2.0-pro" ) }) + +test_that("can handle citations", { + # based on "Write me a 5-paragraph essay on the history of the tidyverse." + messages <- c( + '{"candidates": [{"content": {"parts": [{"text": "a"}]}, "role": "model"}]}', + '{"candidates": [{ + "content": {"parts": [{"text": "a"}]}, + "role": "model", + "citationMetadata": { + "citationSources": [ + { + "startIndex": 1, + "endIndex": 2, + "uri": "https://example.com", + "license": "" + } + ] + } + }]}' + ) + chunks <- lapply(messages, jsonlite::parse_json) + + out <- merge_gemini_chunks(chunks[[1]], chunks[[2]]) + source <- out$candidates[[1]]$citationMetadata$citationSources[[1]] + expect_equal(source$startIndex, 1) + expect_equal(source$endIndex, 2) + expect_equal(source$uri, "https://example.com") + expect_equal(source$license, "") +})