diff --git a/java/com/google/dotprompt/parser/BUILD.bazel b/java/com/google/dotprompt/parser/BUILD.bazel new file mode 100644 index 000000000..e2089b2a0 --- /dev/null +++ b/java/com/google/dotprompt/parser/BUILD.bazel @@ -0,0 +1,56 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# SPDX-License-Identifier: Apache-2.0 + +load("@rules_java//java:defs.bzl", "java_library", "java_test") + +java_library( + name = "parser", + srcs = [ + "Parser.java", + "Picoschema.java", + ], + visibility = ["//visibility:public"], + deps = [ + "//java/com/google/dotprompt/models", + "//java/com/google/dotprompt/resolvers", + "@maven//:com_fasterxml_jackson_core_jackson_databind", + "@maven//:com_fasterxml_jackson_dataformat_jackson_dataformat_yaml", + ], +) + +java_test( + name = "ParserTest", + srcs = ["ParserTest.java"], + test_class = "com.google.dotprompt.parser.ParserTest", + deps = [ + ":parser", + "//java/com/google/dotprompt/models", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + ], +) + +java_test( + name = "PicoschemaTest", + srcs = ["PicoschemaTest.java"], + test_class = "com.google.dotprompt.parser.PicoschemaTest", + deps = [ + ":parser", + "//java/com/google/dotprompt/resolvers", + "@maven//:com_google_truth_truth", + "@maven//:junit_junit", + ], +) diff --git a/java/com/google/dotprompt/parser/Parser.java b/java/com/google/dotprompt/parser/Parser.java new file mode 100644 index 000000000..7e82e94fb --- /dev/null +++ 
b/java/com/google/dotprompt/parser/Parser.java @@ -0,0 +1,523 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.dotprompt.parser; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.dataformat.yaml.YAMLFactory; +import com.google.dotprompt.models.DataArgument; +import com.google.dotprompt.models.MediaContent; +import com.google.dotprompt.models.MediaPart; +import com.google.dotprompt.models.Message; +import com.google.dotprompt.models.Part; +import com.google.dotprompt.models.PendingPart; +import com.google.dotprompt.models.Prompt; +import com.google.dotprompt.models.Role; +import com.google.dotprompt.models.TextPart; +import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +/** + * Parses Dotprompt files into Prompt objects and rendered templates into messages. + * + *

This class handles: + * + *

+ */ +public class Parser { + + /** Prefix for role markers in the template. */ + public static final String ROLE_MARKER_PREFIX = "<<Matches a YAML frontmatter block between "---" markers. Handles different line endings + * (CRLF, LF, CR) and optional trailing whitespace on the marker lines. + */ + private static final Pattern FRONTMATTER_PATTERN = + Pattern.compile("(?ms)^\\s*---[ \\t]*[\\r\\n]+(.*?)^[ \\t]*---[ \\t]*[\\r\\n]+"); + + /** + * Pattern to match role and history markers. + * + *

Examples: {@code <<>>}, {@code <<>>} + */ + public static final Pattern ROLE_AND_HISTORY_MARKER_PATTERN = + Pattern.compile("(<<>>"); + + /** + * Pattern to match media and section markers. + * + *

Examples: {@code <<>>}, {@code <<>>} + */ + public static final Pattern MEDIA_AND_SECTION_MARKER_PATTERN = + Pattern.compile("(<<>>"); + + /** ObjectMapper for parsing YAML frontmatter. */ + private static final ObjectMapper mapper = new ObjectMapper(new YAMLFactory()); + + /** Reserved metadata keywords that are handled specially, not moved to ext. */ + public static final Set RESERVED_METADATA_KEYWORDS = + Set.of( + "config", + "description", + "ext", + "input", + "model", + "name", + "output", + "raw", + "toolDefs", + "tools", + "variant", + "version"); + + /** + * Parses a Dotprompt template string into a Prompt object. + * + * @param content The raw string content of the prompt file (including frontmatter). + * @return The parsed Prompt object containing the template and configuration. + * @throws IOException If parsing the YAML frontmatter fails. + */ + public static Prompt parse(String content) throws IOException { + if (content == null || content.trim().isEmpty()) { + return new Prompt("", Map.of()); + } + + Matcher matcher = FRONTMATTER_PATTERN.matcher(content); + if (matcher.find()) { + String yaml = matcher.group(1); + String template = content.substring(matcher.end()); + + Map config = new HashMap<>(); + if (yaml != null && !yaml.trim().isEmpty()) { + try { + @SuppressWarnings("unchecked") + Map rawConfig = mapper.readValue(yaml, Map.class); + config = expandNamespacedKeys(rawConfig); + config.put("raw", rawConfig); + } catch (IOException e) { + throw e; + } + } + return new Prompt(template, config); + } else { + return new Prompt(content, Map.of()); + } + } + + /** + * Splits a string by a regex pattern while filtering out empty/whitespace-only pieces. + * + * @param source The source string to split. + * @param pattern The pattern to use for splitting. + * @return A list of non-empty string pieces. 
+ */ + public static List splitByRegex(String source, Pattern pattern) { + if (source == null || source.isEmpty()) { + return List.of(); + } + + List result = new ArrayList<>(); + Matcher matcher = pattern.matcher(source); + int lastEnd = 0; + + while (matcher.find()) { + // Add text before the match + if (matcher.start() > lastEnd) { + String beforeMatch = source.substring(lastEnd, matcher.start()); + if (!beforeMatch.trim().isEmpty()) { + result.add(beforeMatch); + } + } + // Add the captured group (without the closing >>>) + String captured = matcher.group(1); + if (captured != null && !captured.trim().isEmpty()) { + result.add(captured); + } + lastEnd = matcher.end(); + } + + // Add remaining text after last match + if (lastEnd < source.length()) { + String remaining = source.substring(lastEnd); + if (!remaining.trim().isEmpty()) { + result.add(remaining); + } + } + + return result; + } + + /** + * Splits a rendered template string by role and history markers. + * + * @param renderedString The template string to split. + * @return List of non-empty string pieces. + */ + public static List splitByRoleAndHistoryMarkers(String renderedString) { + return splitByRegex(renderedString, ROLE_AND_HISTORY_MARKER_PATTERN); + } + + /** + * Splits a source string by media and section markers. + * + * @param source The source string to split. + * @return List of non-empty string pieces. + */ + public static List splitByMediaAndSectionMarkers(String source) { + return splitByRegex(source, MEDIA_AND_SECTION_MARKER_PATTERN); + } + + /** + * Converts a rendered template string into a list of messages. + * + *

Processes role markers and history placeholders to structure the conversation. + * + * @param renderedString The rendered template string to convert. + * @param data Optional data containing message history. + * @return List of structured messages. + */ + public static List toMessages(String renderedString, DataArgument data) { + MessageSource currentMessage = new MessageSource(Role.USER, ""); + List messageSources = new ArrayList<>(); + messageSources.add(currentMessage); + + for (String piece : splitByRoleAndHistoryMarkers(renderedString)) { + if (piece.startsWith(ROLE_MARKER_PREFIX)) { + String roleName = piece.substring(ROLE_MARKER_PREFIX.length()); + Role role = Role.fromString(roleName); + + if (currentMessage.source != null && !currentMessage.source.trim().isEmpty()) { + // Current message has content, create a new message + currentMessage = new MessageSource(role, ""); + messageSources.add(currentMessage); + } else { + // Update the role of the current empty message + currentMessage.role = role; + } + } else if (piece.startsWith(HISTORY_MARKER_PREFIX)) { + // Add the history messages to the message sources + List historyMessages = + transformMessagesToHistory(data != null ? data.messages() : List.of()); + for (Message msg : historyMessages) { + messageSources.add(new MessageSource(msg.role(), msg.content(), msg.metadata())); + } + + // Add a new message source for the model + currentMessage = new MessageSource(Role.MODEL, ""); + messageSources.add(currentMessage); + } else { + // Add the piece to the current message source + currentMessage.source = + (currentMessage.source != null ? currentMessage.source : "") + piece; + } + } + + List messages = messageSourcesToMessages(messageSources); + return insertHistory(messages, data != null ? data.messages() : null); + } + + /** + * Converts a rendered template string into a list of messages with no data context. + * + * @param renderedString The rendered template string to convert. 
+ * @return List of structured messages. + */ + public static List toMessages(String renderedString) { + return toMessages(renderedString, null); + } + + /** + * Transforms an array of messages by adding history metadata to each message. + * + * @param messages Array of messages to transform. + * @return Array of messages with history metadata added. + */ + public static List transformMessagesToHistory(List messages) { + if (messages == null) { + return List.of(); + } + return messages.stream() + .map( + m -> { + Map metadata = new HashMap<>(); + if (m.metadata() != null) { + metadata.putAll(m.metadata()); + } + metadata.put("purpose", "history"); + return new Message(m.role(), m.content(), metadata); + }) + .collect(Collectors.toList()); + } + + /** + * Checks if the messages have history metadata. + * + * @param messages The messages to check. + * @return True if any message has history metadata. + */ + public static boolean messagesHaveHistory(List messages) { + if (messages == null) { + return false; + } + return messages.stream() + .anyMatch(m -> m.metadata() != null && "history".equals(m.metadata().get("purpose"))); + } + + /** + * Inserts historical messages into the conversation at appropriate positions. + * + *

The history is inserted: + * + *

+ * + * @param messages Current array of messages. + * @param history Historical messages to insert. + * @return Messages with history inserted. + */ + public static List insertHistory(List messages, List history) { + // If we have no history or find an existing instance of history, return original + if (history == null || history.isEmpty() || messagesHaveHistory(messages)) { + return messages; + } + + // If there are no messages, return the history + if (messages == null || messages.isEmpty()) { + return history; + } + + Message lastMessage = messages.get(messages.size() - 1); + if (lastMessage.role() == Role.USER) { + // Insert history before the last user message + List result = new ArrayList<>(messages.subList(0, messages.size() - 1)); + result.addAll(history); + result.add(lastMessage); + return result; + } + + // Append history to the end + List result = new ArrayList<>(messages); + result.addAll(history); + return result; + } + + /** + * Converts a source string into a list of parts, processing media and section markers. + * + * @param source The source string to convert into parts. + * @return List of structured parts (text, media, or metadata). + */ + public static List toParts(String source) { + if (source == null || source.isEmpty()) { + return List.of(); + } + return splitByMediaAndSectionMarkers(source).stream() + .map(Parser::parsePart) + .collect(Collectors.toList()); + } + + /** + * Parses a part from a string. + * + * @param piece The piece to parse. + * @return Parsed part (TextPart, MediaPart, or PendingPart). + */ + public static Part parsePart(String piece) { + if (piece.startsWith(MEDIA_MARKER_PREFIX)) { + return parseMediaPart(piece); + } + if (piece.startsWith(SECTION_MARKER_PREFIX)) { + return parseSectionPart(piece); + } + return parseTextPart(piece); + } + + /** + * Parses a media part from a string. + * + * @param piece The piece to parse. + * @return Parsed media part. 
+ * @throws IllegalArgumentException If the piece is not a valid media marker. + */ + public static MediaPart parseMediaPart(String piece) { + if (!piece.startsWith(MEDIA_MARKER_PREFIX)) { + throw new IllegalArgumentException("Invalid media piece: " + piece); + } + String[] parts = piece.split(" "); + String url = parts.length > 1 ? parts[1] : ""; + String contentType = parts.length > 2 ? parts[2] : null; + + MediaContent media = + contentType != null && !contentType.trim().isEmpty() + ? new MediaContent(url, contentType) + : new MediaContent(url, null); + return new MediaPart(media); + } + + /** + * Parses a section part from a string. + * + * @param piece The piece to parse. + * @return Parsed pending part with section metadata. + * @throws IllegalArgumentException If the piece is not a valid section marker. + */ + public static PendingPart parseSectionPart(String piece) { + if (!piece.startsWith(SECTION_MARKER_PREFIX)) { + throw new IllegalArgumentException("Invalid section piece: " + piece); + } + String[] parts = piece.split(" "); + String sectionType = parts.length > 1 ? parts[1] : ""; + Map metadata = new HashMap<>(); + metadata.put("purpose", sectionType); + metadata.put("pending", true); + return new PendingPart(metadata); + } + + /** + * Parses a text part from a string. + * + * @param piece The piece to parse. + * @return Parsed text part. + */ + public static TextPart parseTextPart(String piece) { + return new TextPart(piece); + } + + /** + * Processes an array of message sources into an array of messages. + * + * @param messageSources List of message sources. + * @return List of structured messages. + */ + private static List messageSourcesToMessages(List messageSources) { + List messages = new ArrayList<>(); + for (MessageSource m : messageSources) { + if (m.content != null || (m.source != null && !m.source.isEmpty())) { + List content = m.content != null ? 
m.content : toParts(m.source); + Message message = new Message(m.role, content, m.metadata); + messages.add(message); + } + } + return messages; + } + + /** + * Expands dot-separated keys in the configuration into nested maps. + * + *

Known top-level keys are preserved. Unknown keys are moved into an 'ext' map. + * + * @param input The raw configuration map. + * @return A new map with namespaces expanded. + */ + private static Map expandNamespacedKeys(Map input) { + Map result = new HashMap<>(); + Map ext = new HashMap<>(); + + for (Map.Entry entry : input.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + + if (RESERVED_METADATA_KEYWORDS.contains(key)) { + result.put(key, value); + } else { + // Expand namespace into ext + addNested(ext, key, value); + } + } + + if (!ext.isEmpty()) { + result.put("ext", ext); + } + + return result; + } + + /** + * Adds a namespaced key to a map structure using "last dot" flattening logic. + * + *

e.g. "a.b.c" -> { "a.b": { "c": value } } + * + * @param root The root map to add to. + * @param key The dot-separated key (e.g., "a.b.c"). + * @param value The value to set. + */ + @SuppressWarnings("unchecked") + private static void addNested(Map root, String key, Object value) { + int lastDot = key.lastIndexOf('.'); + if (lastDot == -1) { + root.put(key, value); + } else { + String parentKey = key.substring(0, lastDot); + String childKey = key.substring(lastDot + 1); + + if (!root.containsKey(parentKey) || !(root.get(parentKey) instanceof Map)) { + root.put(parentKey, new HashMap()); + } + ((Map) root.get(parentKey)).put(childKey, value); + } + } + + /** Internal class to represent a message source during parsing. */ + private static class MessageSource { + Role role; + String source; + List content; + Map metadata; + + MessageSource(Role role, String source) { + this.role = role; + this.source = source; + this.content = null; + this.metadata = null; + } + + MessageSource(Role role, List content, Map metadata) { + this.role = role; + this.source = null; + this.content = content; + this.metadata = metadata; + } + } +} diff --git a/java/com/google/dotprompt/parser/ParserTest.java b/java/com/google/dotprompt/parser/ParserTest.java new file mode 100644 index 000000000..7aa07bfa4 --- /dev/null +++ b/java/com/google/dotprompt/parser/ParserTest.java @@ -0,0 +1,523 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.dotprompt.parser; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.dotprompt.models.DataArgument; +import com.google.dotprompt.models.MediaPart; +import com.google.dotprompt.models.Message; +import com.google.dotprompt.models.Part; +import com.google.dotprompt.models.PendingPart; +import com.google.dotprompt.models.Prompt; +import com.google.dotprompt.models.Role; +import com.google.dotprompt.models.TextPart; +import java.io.IOException; +import java.util.List; +import java.util.Map; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Tests for the Parser class. */ +@RunWith(JUnit4.class) +public class ParserTest { + + @Test + public void testParseWithFrontmatter() throws IOException { + String content = + "---\n" + + "input:\n" + + " schema:\n" + + " type: object\n" + + "---\n" + + "Start of the template."; + Prompt prompt = Parser.parse(content); + assertThat(prompt.template()).isEqualTo("Start of the template."); + assertThat(prompt.config()).containsKey("input"); + } + + @Test + public void testParseWithoutFrontmatter() throws IOException { + String content = "Just a template."; + Prompt prompt = Parser.parse(content); + assertThat(prompt.template()).isEqualTo("Just a template."); + assertThat(prompt.config()).isEmpty(); + } + + @Test + public void testParseEmptyFrontmatter() throws IOException { + String content = "---\n---\nTemplate"; + Prompt prompt = Parser.parse(content); + assertThat(prompt.template()).isEqualTo("Template"); + assertThat(prompt.config()).isEmpty(); + } + + @Test + public void testParseWhitespacePreservation() throws IOException { + String content = "---\nfoo: bar\n---\n Indented.\n"; + Prompt prompt = Parser.parse(content); + assertThat(prompt.template()).isEqualTo(" Indented.\n"); + } + + @Test + public void testParseCRLF() throws IOException { + String content = "---\r\nfoo: 
bar\r\n---\r\nBody"; + Prompt prompt = Parser.parse(content); + assertThat(prompt.template()).isEqualTo("Body"); + @SuppressWarnings("unchecked") + Map ext = (Map) prompt.config().get("ext"); + assertThat(ext).containsEntry("foo", "bar"); + } + + @Test + public void testParseMultilineFrontmatter() throws IOException { + String content = "---\nfoo: bar\nbaz: qux\n---\nBody"; + Prompt prompt = Parser.parse(content); + assertThat(prompt.template()).isEqualTo("Body"); + @SuppressWarnings("unchecked") + Map ext = (Map) prompt.config().get("ext"); + assertThat(ext).containsEntry("foo", "bar"); + assertThat(ext).containsEntry("baz", "qux"); + } + + @Test + public void testParseExtraMarkers() throws IOException { + String content = "---\nfoo: bar\n---\nBody\n---\nExtra"; + Prompt prompt = Parser.parse(content); + assertThat(prompt.template()).isEqualTo("Body\n---\nExtra"); + @SuppressWarnings("unchecked") + Map ext = (Map) prompt.config().get("ext"); + assertThat(ext).containsEntry("foo", "bar"); + } + + @Test + public void testParseWithCR() throws IOException { + String content = "---\rfoo: bar\r---\rBody"; + Prompt prompt = Parser.parse(content); + assertThat(prompt.template()).isEqualTo("Body"); + @SuppressWarnings("unchecked") + Map ext = (Map) prompt.config().get("ext"); + assertThat(ext).containsEntry("foo", "bar"); + } + + @Test + public void testParseFrontmatterWithExtraSpaces() throws IOException { + String content = "--- \nfoo: bar\n--- \nBody"; + Prompt prompt = Parser.parse(content); + assertThat(prompt.template()).isEqualTo("Body"); + @SuppressWarnings("unchecked") + Map ext = (Map) prompt.config().get("ext"); + assertThat(ext).containsEntry("foo", "bar"); + } + + @Test + public void testParseNamespacedKeys() throws IOException { + String content = "---\na.b.c: val\n---\nBody"; + Prompt prompt = Parser.parse(content); + @SuppressWarnings("unchecked") + Map ext = (Map) prompt.config().get("ext"); + + // Expect: { "a.b": { "c": "val" } } + 
assertThat(ext).containsKey("a.b"); + @SuppressWarnings("unchecked") + Map ab = (Map) ext.get("a.b"); + assertThat(ab).containsEntry("c", "val"); + } + + @Test + public void testParseIncompleteFrontmatter() throws IOException { + String content = "---\nfoo: bar\nBody"; // Missing second marker + Prompt prompt = Parser.parse(content); + assertThat(prompt.template()).isEqualTo(content); + assertThat(prompt.config()).isEmpty(); + } + + @Test + public void testRoleAndHistoryMarkerPattern_validPatterns() { + String[] validPatterns = { + "<<>>", + "<<>>", + "<<>>", + "<<>>", + "<<>>", + "<<>>" + }; + + for (String pattern : validPatterns) { + assertThat(Parser.ROLE_AND_HISTORY_MARKER_PATTERN.matcher(pattern).find()).isTrue(); + } + } + + @Test + public void testRoleAndHistoryMarkerPattern_invalidPatterns() { + String[] invalidPatterns = { + "<<>>", // uppercase not allowed + "<<>>", // numbers not allowed + "<<>>", // needs at least one letter + "<<>>", // missing role value + "<<>>", // history should be exact + "<<>>", // history must be lowercase + "dotprompt:role:user", // missing brackets + "<<>>", // incomplete opening + }; + + for (String pattern : invalidPatterns) { + assertThat(Parser.ROLE_AND_HISTORY_MARKER_PATTERN.matcher(pattern).find()).isFalse(); + } + } + + @Test + public void testSplitByRoleAndHistoryMarkers_noMarkers() { + List result = Parser.splitByRoleAndHistoryMarkers("Hello World"); + assertThat(result).containsExactly("Hello World"); + } + + @Test + public void testSplitByRoleAndHistoryMarkers_singleMarker() { + List result = + Parser.splitByRoleAndHistoryMarkers("Hello <<>> world"); + assertThat(result).containsExactly("Hello ", "<< result = Parser.splitByRoleAndHistoryMarkers(" <<>> "); + assertThat(result).containsExactly("<< result = + Parser.splitByRoleAndHistoryMarkers("<<>><<>>"); + assertThat(result).containsExactly("<< result = + Parser.splitByRoleAndHistoryMarkers( + "Start <<>> middle <<>> end"); + assertThat(result) + .containsExactly( 
+ "Start ", "<< result = Parser.splitByMediaAndSectionMarkers("Hello World"); + assertThat(result).containsExactly("Hello World"); + } + + @Test + public void testSplitByMediaAndSectionMarkers_mediaMarker() { + List result = + Parser.splitByMediaAndSectionMarkers( + "<<>> https://example.com/image.jpg"); + assertThat(result).containsExactly("<< result = Parser.toMessages("Hello world"); + + assertThat(result).hasSize(1); + assertThat(result.get(0).role()).isEqualTo(Role.USER); + assertThat(result.get(0).content()).hasSize(1); + assertThat(((TextPart) result.get(0).content().get(0)).text()).isEqualTo("Hello world"); + } + + @Test + public void testToMessages_singleRoleMarker() { + List result = Parser.toMessages("<<>>Hello world"); + + assertThat(result).hasSize(1); + assertThat(result.get(0).role()).isEqualTo(Role.MODEL); + assertThat(((TextPart) result.get(0).content().get(0)).text()).isEqualTo("Hello world"); + } + + @Test + public void testToMessages_multipleRoleMarkers() { + String renderedString = + "<<>>System instructions\n" + + "<<>>User query\n" + + "<<>>Model response"; + List result = Parser.toMessages(renderedString); + + assertThat(result).hasSize(3); + + assertThat(result.get(0).role()).isEqualTo(Role.SYSTEM); + assertThat(((TextPart) result.get(0).content().get(0)).text()) + .isEqualTo("System instructions\n"); + + assertThat(result.get(1).role()).isEqualTo(Role.USER); + assertThat(((TextPart) result.get(1).content().get(0)).text()).isEqualTo("User query\n"); + + assertThat(result.get(2).role()).isEqualTo(Role.MODEL); + assertThat(((TextPart) result.get(2).content().get(0)).text()).isEqualTo("Model response"); + } + + @Test + public void testToMessages_updatesRoleOfEmptyMessage() { + String renderedString = "<<>><<>>Response"; + List result = Parser.toMessages(renderedString); + + // Should only have one message since first role marker has no content + assertThat(result).hasSize(1); + assertThat(result.get(0).role()).isEqualTo(Role.MODEL); + 
assertThat(((TextPart) result.get(0).content().get(0)).text()).isEqualTo("Response"); + } + + @Test + public void testToMessages_emptyInputString() { + List result = Parser.toMessages(""); + assertThat(result).isEmpty(); + } + + @Test + public void testToMessages_historyMarkersAddMetadata() { + String renderedString = "<<>>Query<<>>Follow-up"; + List historyMessages = + List.of( + new Message(Role.USER, List.of(new TextPart("Previous question")), null), + new Message(Role.MODEL, List.of(new TextPart("Previous answer")), null)); + + DataArgument data = new DataArgument(null, null, historyMessages, null); + List result = Parser.toMessages(renderedString, data); + + assertThat(result).hasSize(4); + + // First message is the user query + assertThat(result.get(0).role()).isEqualTo(Role.USER); + assertThat(((TextPart) result.get(0).content().get(0)).text()).isEqualTo("Query"); + + // Next two messages are history with metadata + assertThat(result.get(1).role()).isEqualTo(Role.USER); + assertThat(result.get(1).metadata()).containsEntry("purpose", "history"); + + assertThat(result.get(2).role()).isEqualTo(Role.MODEL); + assertThat(result.get(2).metadata()).containsEntry("purpose", "history"); + + // Last message is the follow-up + assertThat(result.get(3).role()).isEqualTo(Role.MODEL); + assertThat(((TextPart) result.get(3).content().get(0)).text()).isEqualTo("Follow-up"); + } + + @Test + public void testToMessages_emptyHistory() { + String renderedString = "<<>>Query<<>>Follow-up"; + DataArgument data = new DataArgument(null, null, List.of(), null); + List result = Parser.toMessages(renderedString, data); + + assertThat(result).hasSize(2); + assertThat(result.get(0).role()).isEqualTo(Role.USER); + assertThat(result.get(1).role()).isEqualTo(Role.MODEL); + } + + @Test + public void testTransformMessagesToHistory_addsMetadata() { + List messages = + List.of( + new Message(Role.USER, List.of(new TextPart("Hello")), null), + new Message(Role.MODEL, List.of(new TextPart("Hi 
there")), null)); + + List result = Parser.transformMessagesToHistory(messages); + + assertThat(result).hasSize(2); + assertThat(result.get(0).metadata()).containsEntry("purpose", "history"); + assertThat(result.get(1).metadata()).containsEntry("purpose", "history"); + } + + @Test + public void testTransformMessagesToHistory_preservesExistingMetadata() { + List messages = + List.of(new Message(Role.USER, List.of(new TextPart("Hello")), Map.of("foo", "bar"))); + + List result = Parser.transformMessagesToHistory(messages); + + assertThat(result).hasSize(1); + assertThat(result.get(0).metadata()).containsEntry("foo", "bar"); + assertThat(result.get(0).metadata()).containsEntry("purpose", "history"); + } + + @Test + public void testTransformMessagesToHistory_emptyArray() { + List result = Parser.transformMessagesToHistory(List.of()); + assertThat(result).isEmpty(); + } + + @Test + public void testMessagesHaveHistory_true() { + List messages = + List.of( + new Message(Role.USER, List.of(new TextPart("Hello")), Map.of("purpose", "history"))); + + assertThat(Parser.messagesHaveHistory(messages)).isTrue(); + } + + @Test + public void testMessagesHaveHistory_false() { + List messages = List.of(new Message(Role.USER, List.of(new TextPart("Hello")), null)); + + assertThat(Parser.messagesHaveHistory(messages)).isFalse(); + } + + @Test + public void testInsertHistory_returnsOriginalIfNoHistory() { + List messages = List.of(new Message(Role.USER, List.of(new TextPart("Hello")), null)); + + List result = Parser.insertHistory(messages, List.of()); + + assertThat(result).isEqualTo(messages); + } + + @Test + public void testInsertHistory_returnsOriginalIfHistoryExists() { + List messages = + List.of( + new Message(Role.USER, List.of(new TextPart("Hello")), Map.of("purpose", "history"))); + + List history = + List.of( + new Message( + Role.MODEL, List.of(new TextPart("Previous")), Map.of("purpose", "history"))); + + List result = Parser.insertHistory(messages, history); + + 
assertThat(result).isEqualTo(messages); + } + + @Test + public void testInsertHistory_insertsBeforeLastUserMessage() { + List messages = + List.of( + new Message(Role.SYSTEM, List.of(new TextPart("System prompt")), null), + new Message(Role.USER, List.of(new TextPart("Current question")), null)); + + List history = + List.of( + new Message( + Role.MODEL, List.of(new TextPart("Previous")), Map.of("purpose", "history"))); + + List result = Parser.insertHistory(messages, history); + + assertThat(result).hasSize(3); + assertThat(result.get(0).role()).isEqualTo(Role.SYSTEM); + assertThat(result.get(1).role()).isEqualTo(Role.MODEL); + assertThat(result.get(1).metadata()).containsEntry("purpose", "history"); + assertThat(result.get(2).role()).isEqualTo(Role.USER); + } + + @Test + public void testInsertHistory_appendsIfNoUserMessageIsLast() { + List messages = + List.of( + new Message(Role.SYSTEM, List.of(new TextPart("System prompt")), null), + new Message(Role.MODEL, List.of(new TextPart("Model message")), null)); + + List history = + List.of( + new Message( + Role.MODEL, List.of(new TextPart("Previous")), Map.of("purpose", "history"))); + + List result = Parser.insertHistory(messages, history); + + assertThat(result).hasSize(3); + assertThat(result.get(0).role()).isEqualTo(Role.SYSTEM); + assertThat(result.get(1).role()).isEqualTo(Role.MODEL); + assertThat(result.get(2).role()).isEqualTo(Role.MODEL); + assertThat(result.get(2).metadata()).containsEntry("purpose", "history"); + } + + @Test + public void testToParts_simpleText() { + List result = Parser.toParts("Hello World"); + assertThat(result).hasSize(1); + assertThat(result.get(0)).isInstanceOf(TextPart.class); + assertThat(((TextPart) result.get(0)).text()).isEqualTo("Hello World"); + } + + @Test + public void testToParts_emptyString() { + List result = Parser.toParts(""); + assertThat(result).isEmpty(); + } + + @Test + public void testParsePart_textPart() { + Part result = Parser.parsePart("Hello World"); + 
assertThat(result).isInstanceOf(TextPart.class); + assertThat(((TextPart) result).text()).isEqualTo("Hello World"); + } + + @Test + public void testParsePart_mediaPart() { + Part result = Parser.parsePart("<<>> https://example.com/image.jpg"); + assertThat(result).isInstanceOf(MediaPart.class); + assertThat(((MediaPart) result).media().url()).isEqualTo("https://example.com/image.jpg"); + } + + @Test + public void testParsePart_sectionPart() { + Part result = Parser.parsePart("<<>> code"); + assertThat(result).isInstanceOf(PendingPart.class); + assertThat(((PendingPart) result).metadata()).containsEntry("purpose", "code"); + assertThat(((PendingPart) result).metadata()).containsEntry("pending", true); + } + + @Test + public void testParseMediaPart_basic() { + MediaPart result = + Parser.parseMediaPart("<<>> https://example.com/image.jpg"); + assertThat(result.media().url()).isEqualTo("https://example.com/image.jpg"); + assertThat(result.media().contentType()).isNull(); + } + + @Test + public void testParseMediaPart_withContentType() { + MediaPart result = + Parser.parseMediaPart("<<>> https://example.com/image.jpg image/jpeg"); + assertThat(result.media().url()).isEqualTo("https://example.com/image.jpg"); + assertThat(result.media().contentType()).isEqualTo("image/jpeg"); + } + + @Test(expected = IllegalArgumentException.class) + public void testParseMediaPart_invalidPrefix() { + Parser.parseMediaPart("https://example.com/image.jpg"); + } + + @Test + public void testParseSectionPart_basic() { + PendingPart result = Parser.parseSectionPart("<<>> code"); + assertThat(result.metadata()).containsEntry("purpose", "code"); + assertThat(result.metadata()).containsEntry("pending", true); + } + + @Test(expected = IllegalArgumentException.class) + public void testParseSectionPart_invalidPrefix() { + Parser.parseSectionPart("code"); + } + + @Test + public void testParseTextPart() { + TextPart result = Parser.parseTextPart("Hello World"); + 
assertThat(result.text()).isEqualTo("Hello World"); + } +} diff --git a/java/com/google/dotprompt/parser/Picoschema.java b/java/com/google/dotprompt/parser/Picoschema.java new file mode 100644 index 000000000..e82f56475 --- /dev/null +++ b/java/com/google/dotprompt/parser/Picoschema.java @@ -0,0 +1,304 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.dotprompt.parser; + +import com.google.dotprompt.resolvers.SchemaResolver; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; + +/** + * Picoschema parser and related helpers. + * + *

Picoschema is a compact, YAML-optimized schema definition format. This class compiles + * Picoschema to JSON Schema. + * + *

Supported features: + * + *

    + *
  • Scalar types (any, boolean, integer, null, number, string) + *
  • Type descriptions (e.g. {@code name(string, The name)}) + *
  • Optional fields (denoted by {@code ?}) + *
  • Arrays ({@code items(string)}) + *
  • Objects (nested maps) + *
  • Enums ({@code status(enum, [active, inactive])}) + *
  • Wildcard properties ({@code (*)}) + *
+ */ +public class Picoschema { + + private static final Set JSON_SCHEMA_SCALAR_TYPES = + Set.of("any", "boolean", "integer", "null", "number", "string"); + + private static final String WILDCARD_PROPERTY_NAME = "(*)"; + + /** Checks if a schema is already in JSON Schema format. */ + public static boolean isJsonSchema(Object schema) { + if (!(schema instanceof Map)) { + return false; + } + Map map = (Map) schema; + Object type = map.get("type"); + if (type == null) { + return map.containsKey("properties"); + } + String typeStr = (String) type; + return JSON_SCHEMA_SCALAR_TYPES.contains(typeStr) + || "object".equals(typeStr) + || "array".equals(typeStr); + } + + /** + * Parses a Picoschema definition into a JSON Schema. + * + * @param schema The Picoschema definition (can be a Map or String). + * @return A future containing the equivalent JSON Schema. + */ + public static CompletableFuture> parse(Object schema) { + return parse(schema, null); + } + + /** + * Parses a Picoschema definition into a JSON Schema with reference resolution. + * + * @param schema The Picoschema definition. + * @param resolver A function to resolve named schemas asynchronously. + * @return A future containing the equivalent JSON Schema. 
+ */ + public static CompletableFuture> parse( + Object schema, SchemaResolver resolver) { + if (schema == null) { + return CompletableFuture.completedFuture(null); + } + + if (schema instanceof String) { + Description desc = extractDescription((String) schema); + String typeName = desc.type; + String description = desc.description; + + if (JSON_SCHEMA_SCALAR_TYPES.contains(typeName)) { + Map out = new HashMap<>(); + if (!"any".equals(typeName)) { + out.put("type", typeName); + } + if (description != null) { + out.put("description", description); + } + return CompletableFuture.completedFuture(out); + } + + // Resolve named schema asynchronously + if (resolver != null) { + return resolver + .resolve(typeName) + .thenApply( + resolved -> { + if (resolved != null) { + Map out = new HashMap<>(resolved); + if (description != null) { + out.put("description", description); + } + return out; + } + throw new IllegalArgumentException("Unsupported scalar type: " + typeName); + }); + } + + throw new IllegalArgumentException("Unsupported scalar type: " + typeName); + } + + if (schema instanceof Map) { + if (isJsonSchema(schema)) { + Map map = (Map) schema; + if (!map.containsKey("type") && map.containsKey("properties")) { + Map newSchema = new HashMap<>((Map) schema); + newSchema.put("type", "object"); + return CompletableFuture.completedFuture(newSchema); + } + return CompletableFuture.completedFuture((Map) schema); + } + return parsePico((Map) schema, resolver); + } + + throw new IllegalArgumentException( + "Picoschema must be a string or object. Got: " + schema.getClass()); + } + + /** + * Parses a Picoschema object definition asynchronously. + * + * @param obj The map representing the object schema. + * @param resolver The schema resolver for named types. + * @return A future containing the JSON Schema object definition. 
+ */ + @SuppressWarnings("unchecked") + private static CompletableFuture> parsePico( + Map obj, SchemaResolver resolver) { + + Map schema = new HashMap<>(); + schema.put("type", "object"); + schema.put("properties", new HashMap()); + schema.put("additionalProperties", false); + List required = new ArrayList<>(); + schema.put("required", required); + + List> futures = new ArrayList<>(); + + for (Map.Entry entry : obj.entrySet()) { + String key = entry.getKey(); + Object value = entry.getValue(); + + if (WILDCARD_PROPERTY_NAME.equals(key)) { + futures.add( + parse(value, resolver) + .thenAccept( + parsed -> { + schema.put("additionalProperties", parsed); + })); + continue; + } + + String name; + String typeInfo = null; + int parenIndex = key.indexOf('('); + if (parenIndex != -1 && key.endsWith(")")) { + name = key.substring(0, parenIndex); + typeInfo = key.substring(parenIndex + 1, key.length() - 1); + } else { + name = key; + } + + boolean isOptional = name.endsWith("?"); + String propertyName = isOptional ? name.substring(0, name.length() - 1) : name; + + if (!isOptional) { + required.add(propertyName); + } + + final boolean finalIsOptional = isOptional; + final String finalPropertyName = propertyName; + + if (typeInfo == null) { + futures.add( + parse(value, resolver) + .thenAccept( + prop -> { + applyOptionalNullability(prop, finalIsOptional); + ((Map) schema.get("properties")).put(finalPropertyName, prop); + })); + } else { + Description typeDesc = extractDescription(typeInfo); + String typeName = typeDesc.type; + String description = typeDesc.description; + + if ("array".equals(typeName)) { + futures.add( + parse(value, resolver) + .thenAccept( + items -> { + Map prop = new HashMap<>(); + prop.put( + "type", finalIsOptional ? 
Arrays.asList("array", "null") : "array"); + prop.put("items", items); + if (description != null) { + prop.put("description", description); + } + ((Map) schema.get("properties")) + .put(finalPropertyName, prop); + })); + } else if ("object".equals(typeName)) { + futures.add( + parse(value, resolver) + .thenAccept( + prop -> { + applyOptionalNullability(prop, finalIsOptional); + if (description != null) { + prop.put("description", description); + } + ((Map) schema.get("properties")) + .put(finalPropertyName, prop); + })); + } else if ("enum".equals(typeName)) { + Map prop = new HashMap<>(); + if (finalIsOptional && value instanceof List) { + List enums = new ArrayList<>((List) value); + if (!enums.contains(null)) { + enums.add(null); + } + prop.put("enum", enums); + } else { + prop.put("enum", value); + } + if (description != null) { + prop.put("description", description); + } + ((Map) schema.get("properties")).put(finalPropertyName, prop); + } else { + throw new IllegalArgumentException( + "Picoschema: parenthetical types must be 'object', 'array' or 'enum', got: " + + typeName); + } + } + } + + return CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])) + .thenApply( + v -> { + if (required.isEmpty()) { + schema.remove("required"); + } + return schema; + }); + } + + /** Applies nullability to optional properties. */ + @SuppressWarnings("unchecked") + private static void applyOptionalNullability(Map prop, boolean isOptional) { + if (!isOptional) return; + + Object currentType = prop.get("type"); + if (currentType instanceof String) { + prop.put("type", Arrays.asList(currentType, "null")); + } else if (prop.containsKey("enum")) { + List enums = new ArrayList<>((List) prop.get("enum")); + if (!enums.contains(null)) { + enums.add(null); + } + prop.put("enum", enums); + } + } + + /** Internal record for parsed type descriptions. 
*/ + private record Description(String type, String description) {} + + /** Extracts type and description from a string like "string, The name". */ + private static Description extractDescription(String input) { + if (!input.contains(",")) { + return new Description(input, null); + } + int idx = input.indexOf(','); + String type = input.substring(0, idx).trim(); + String desc = input.substring(idx + 1).trim(); + return new Description(type, desc); + } +} diff --git a/java/com/google/dotprompt/parser/PicoschemaTest.java b/java/com/google/dotprompt/parser/PicoschemaTest.java new file mode 100644 index 000000000..c7cb9cb41 --- /dev/null +++ b/java/com/google/dotprompt/parser/PicoschemaTest.java @@ -0,0 +1,449 @@ +/* + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * SPDX-License-Identifier: Apache-2.0 + */ + +package com.google.dotprompt.parser; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.dotprompt.resolvers.SchemaResolver; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Comprehensive tests for the Picoschema parser. 
*/ +@RunWith(JUnit4.class) +public class PicoschemaTest { + + private static Map parseSync(Object schema) + throws ExecutionException, InterruptedException { + return Picoschema.parse(schema).get(); + } + + private static Map parseSync(Object schema, SchemaResolver resolver) + throws ExecutionException, InterruptedException { + return Picoschema.parse(schema, resolver).get(); + } + + @Test + public void parse_nullSchema_returnsNull() throws Exception { + assertThat(Picoschema.parse(null).get()).isNull(); + } + + @Test + public void parse_scalarTypeSchema() throws Exception { + assertThat(parseSync("string")).isEqualTo(Map.of("type", "string")); + } + + @Test + public void parse_scalarTypeNumber() throws Exception { + assertThat(parseSync("number")).isEqualTo(Map.of("type", "number")); + } + + @Test + public void parse_scalarTypeInteger() throws Exception { + assertThat(parseSync("integer")).isEqualTo(Map.of("type", "integer")); + } + + @Test + public void parse_scalarTypeBoolean() throws Exception { + assertThat(parseSync("boolean")).isEqualTo(Map.of("type", "boolean")); + } + + @Test + public void parse_scalarTypeNull() throws Exception { + assertThat(parseSync("null")).isEqualTo(Map.of("type", "null")); + } + + @Test + public void parse_anyType_noTypeField() throws Exception { + // 'any' type should produce an empty schema (no type field) + assertThat(parseSync("any")).isEqualTo(Map.of()); + } + + @Test + public void parse_objectSchema() throws Exception { + Map schema = + Map.of("type", "object", "properties", Map.of("name", Map.of("type", "string"))); + assertThat(parseSync(schema)).isEqualTo(schema); + } + + @Test + public void parse_invalidSchemaType() { + assertThrows(IllegalArgumentException.class, () -> parseSync(123)); + } + + @Test + public void parse_namedSchema() throws Exception { + SchemaResolver resolver = + SchemaResolver.fromSync( + name -> { + if ("CustomType".equals(name)) return Map.of("type", "integer"); + return null; + }); + Map result = 
parseSync("CustomType", resolver); + assertThat(result).isEqualTo(Map.of("type", "integer")); + } + + @Test + public void parse_namedSchemaWithDescription() throws Exception { + SchemaResolver resolver = + SchemaResolver.fromSync( + name -> { + if ("DescribedType".equals(name)) return Map.of("type", "boolean"); + return null; + }); + Map result = parseSync("DescribedType, this is a description", resolver); + assertThat(result).isEqualTo(Map.of("type", "boolean", "description", "this is a description")); + } + + @Test + public void parse_namedSchemaNotFound_throwsError() { + SchemaResolver resolver = SchemaResolver.fromSync(name -> null); + ExecutionException exception = + assertThrows(ExecutionException.class, () -> parseSync("NonExistentSchema", resolver)); + assertThat(exception.getCause()).isInstanceOf(IllegalArgumentException.class); + assertThat(exception.getCause().getMessage()).contains("Unsupported scalar type"); + } + + @Test + public void parse_namedSchemaNoResolver_throwsError() { + assertThrows(IllegalArgumentException.class, () -> parseSync("CustomSchema")); + } + + @Test + public void parse_scalarTypeSchemaWithDescription() throws Exception { + assertThat(parseSync("string, a string")) + .isEqualTo(Map.of("type", "string", "description", "a string")); + } + + @Test + public void parse_anyTypeWithDescription() throws Exception { + assertThat(parseSync("any, can be any type")) + .isEqualTo(Map.of("description", "can be any type")); + } + + @Test + public void parse_propertiesObjectShorthand() throws Exception { + Map schema = Map.of("name", "string"); + Map expected = + Map.of( + "type", + "object", + "properties", + Map.of("name", Map.of("type", "string")), + "required", + List.of("name"), + "additionalProperties", + false); + assertThat(parseSync(schema)).isEqualTo(expected); + } + + @Test + public void parse_propertiesObjectShorthandMultipleFields() throws Exception { + // Using HashMap because Map.of doesn't guarantee order + Map schema = new 
HashMap<>(); + schema.put("name", "string"); + schema.put("age", "integer"); + + Map result = parseSync(schema); + + assertThat(result).containsEntry("type", "object"); + assertThat(result).containsKey("properties"); + @SuppressWarnings("unchecked") + Map props = (Map) result.get("properties"); + assertThat(props).containsKey("name"); + assertThat(props).containsKey("age"); + } + + @Test + public void parse_picoArrayType() throws Exception { + Map schema = Map.of("names(array)", "string"); + Map expected = + Map.of( + "type", + "object", + "properties", + Map.of("names", Map.of("type", "array", "items", Map.of("type", "string"))), + "required", + List.of("names"), + "additionalProperties", + false); + assertThat(parseSync(schema)).isEqualTo(expected); + } + + @Test + public void parse_picoArrayTypeWithDescription() throws Exception { + Map schema = Map.of("items(array, list of items)", "string"); + Map result = parseSync(schema); + + assertThat(result).containsEntry("type", "object"); + @SuppressWarnings("unchecked") + Map props = (Map) result.get("properties"); + @SuppressWarnings("unchecked") + Map itemsProp = (Map) props.get("items"); + assertThat(itemsProp).containsEntry("type", "array"); + assertThat(itemsProp).containsEntry("description", "list of items"); + } + + @Test + public void parse_picoOptionalArrayWithDescription() throws Exception { + Map schema = Map.of("items?(array, list of items)", "string"); + Map result = parseSync(schema); + + assertThat(result).containsEntry("type", "object"); + @SuppressWarnings("unchecked") + Map props = (Map) result.get("properties"); + @SuppressWarnings("unchecked") + Map itemsProp = (Map) props.get("items"); + assertThat(itemsProp).containsEntry("type", Arrays.asList("array", "null")); + assertThat(itemsProp).containsEntry("description", "list of items"); + // Optional properties should not be in required + assertThat(result).doesNotContainKey("required"); + } + + @Test + public void parse_picoNestedArray() throws 
Exception { + // Nested array: items(array) containing props(array) + Map innerSchema = Map.of("props(array)", "string"); + Map schema = Map.of("items(array)", innerSchema); + Map result = parseSync(schema); + + assertThat(result).containsEntry("type", "object"); + @SuppressWarnings("unchecked") + Map props = (Map) result.get("properties"); + @SuppressWarnings("unchecked") + Map itemsProp = (Map) props.get("items"); + assertThat(itemsProp).containsEntry("type", "array"); + @SuppressWarnings("unchecked") + Map itemsItems = (Map) itemsProp.get("items"); + assertThat(itemsItems).containsEntry("type", "object"); + @SuppressWarnings("unchecked") + Map innerProps = (Map) itemsItems.get("properties"); + assertThat(innerProps).containsKey("props"); + } + + @Test + public void parse_picoEnumType() throws Exception { + Map schema = Map.of("status(enum)", List.of("active", "inactive")); + Map expected = + Map.of( + "type", + "object", + "properties", + Map.of("status", Map.of("enum", List.of("active", "inactive"))), + "required", + List.of("status"), + "additionalProperties", + false); + assertThat(parseSync(schema)).isEqualTo(expected); + } + + @Test + public void parse_picoEnumWithDescription() throws Exception { + Map schema = Map.of("status(enum, the status)", List.of("active", "inactive")); + Map result = parseSync(schema); + + assertThat(result).containsEntry("type", "object"); + @SuppressWarnings("unchecked") + Map props = (Map) result.get("properties"); + @SuppressWarnings("unchecked") + Map statusProp = (Map) props.get("status"); + assertThat(statusProp).containsEntry("enum", List.of("active", "inactive")); + assertThat(statusProp).containsEntry("description", "the status"); + } + + @Test + public void parse_picoEnumWithOptionalAndNull() throws Exception { + Map schema = Map.of("status?(enum)", List.of("active", "inactive")); + Map expected = + Map.of( + "type", + "object", + "properties", + Map.of("status", Map.of("enum", Arrays.asList("active", "inactive", null))), 
+ "additionalProperties", + false); + assertThat(parseSync(schema)).isEqualTo(expected); + } + + @Test + public void parse_picoOptionalProperty() throws Exception { + Map schema = Map.of("name?", "string"); + Map expected = + Map.of( + "type", + "object", + "properties", + Map.of("name", Map.of("type", Arrays.asList("string", "null"))), + "additionalProperties", + false); + assertThat(parseSync(schema)).isEqualTo(expected); + } + + @Test + public void parse_picoWildcardProperty() throws Exception { + Map schema = Map.of("(*)", "string"); + Map expected = + Map.of( + "type", "object", + "properties", Map.of(), + "additionalProperties", Map.of("type", "string")); + assertThat(parseSync(schema)).isEqualTo(expected); + } + + @Test + public void parse_picoNestedObject() throws Exception { + Map schema = Map.of("address(object)", Map.of("street", "string")); + Map expected = + Map.of( + "type", + "object", + "properties", + Map.of( + "address", + Map.of( + "type", + "object", + "properties", + Map.of("street", Map.of("type", "string")), + "required", + List.of("street"), + "additionalProperties", + false)), + "required", + List.of("address"), + "additionalProperties", + false); + assertThat(parseSync(schema)).isEqualTo(expected); + } + + @Test + public void parse_picoNestedObjectWithDescription() throws Exception { + Map schema = Map.of("address(object, the address)", Map.of("street", "string")); + Map result = parseSync(schema); + + @SuppressWarnings("unchecked") + Map props = (Map) result.get("properties"); + @SuppressWarnings("unchecked") + Map addressProp = (Map) props.get("address"); + assertThat(addressProp).containsEntry("type", "object"); + assertThat(addressProp).containsEntry("description", "the address"); + } + + @Test + public void parse_picoDescriptionOnType() throws Exception { + Map schema = Map.of("name", "string, a name"); + Map expected = + Map.of( + "type", + "object", + "properties", + Map.of("name", Map.of("type", "string", "description", "a name")), 
+ "required", + List.of("name"), + "additionalProperties", + false); + assertThat(parseSync(schema)).isEqualTo(expected); + } + + @Test + public void parse_picoDescriptionOnCustomSchema() throws Exception { + SchemaResolver resolver = + SchemaResolver.fromSync( + name -> { + if ("CustomSchema".equals(name)) return Map.of("type", "string"); + return null; + }); + Map schema = Map.of("field1", "CustomSchema, a custom field"); + Map result = parseSync(schema, resolver); + + @SuppressWarnings("unchecked") + Map props = (Map) result.get("properties"); + @SuppressWarnings("unchecked") + Map field1Prop = (Map) props.get("field1"); + assertThat(field1Prop).containsEntry("type", "string"); + assertThat(field1Prop).containsEntry("description", "a custom field"); + } + + @Test + public void parse_asyncResolver() throws Exception { + SchemaResolver asyncResolver = + name -> { + // Simulate async operation + return CompletableFuture.supplyAsync( + () -> { + if ("AsyncType".equals(name)) return Map.of("type", "number"); + return null; + }); + }; + Map result = parseSync("AsyncType", asyncResolver); + assertThat(result).isEqualTo(Map.of("type", "number")); + } + + @Test + public void parse_jsonSchemaPassthrough() throws Exception { + Map jsonSchema = + Map.of("type", "object", "properties", Map.of("name", Map.of("type", "string"))); + assertThat(parseSync(jsonSchema)).isEqualTo(jsonSchema); + } + + @Test + public void parse_jsonSchemaWithPropertiesOnly_addsType() throws Exception { + // If 'properties' is present but 'type' is not, add type: object + Map schema = Map.of("properties", Map.of("name", Map.of("type", "string"))); + Map result = parseSync(schema); + assertThat(result).containsEntry("type", "object"); + } + + @Test + public void isJsonSchema_withTypeObject() { + assertThat(Picoschema.isJsonSchema(Map.of("type", "object"))).isTrue(); + } + + @Test + public void isJsonSchema_withTypeString() { + assertThat(Picoschema.isJsonSchema(Map.of("type", "string"))).isTrue(); + } 
+ + @Test + public void isJsonSchema_withTypeArray() { + assertThat(Picoschema.isJsonSchema(Map.of("type", "array"))).isTrue(); + } + + @Test + public void isJsonSchema_withPropertiesOnly() { + assertThat(Picoschema.isJsonSchema(Map.of("properties", Map.of()))).isTrue(); + } + + @Test + public void isJsonSchema_picoschemaObject() { + assertThat(Picoschema.isJsonSchema(Map.of("name", "string"))).isFalse(); + } + + @Test + public void isJsonSchema_nonMap() { + assertThat(Picoschema.isJsonSchema("string")).isFalse(); + } +} diff --git a/js/src/picoschema.test.ts b/js/src/picoschema.test.ts new file mode 100644 index 000000000..1d2bb9321 --- /dev/null +++ b/js/src/picoschema.test.ts @@ -0,0 +1,371 @@ +/** + * Copyright 2025 Google LLC + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 *
 * SPDX-License-Identifier: Apache-2.0
 */

// NOTE(review): this chunk arrived as collapsed unified-diff text ("+"-prefixed
// lines); the code below is the same TypeScript restored to conventional
// formatting with review comments added — no executable tokens were changed.

import { describe, expect, it } from 'vitest';
import { PicoschemaParser, picoschema } from './picoschema';
// NOTE(review): `JSONSchema` is imported but never referenced in this file —
// consider dropping it (type-only import, so there is no runtime cost either way).
import type { JSONSchema, SchemaResolver } from './types';

// Tests for the standalone `picoschema()` entry point: expansion of Picoschema
// shorthand (scalars, object/array/enum shorthand, optional `?` keys, wildcard
// `(*)` keys, inline `, description` suffixes) into full JSON Schema.
describe('picoschema', () => {
  // Bare scalar type names expand to `{ type: <name> }`; `null` input is
  // passed through as `null` (no schema).
  describe('parse null and basic types', () => {
    it('should return null for null input', async () => {
      const result = await picoschema(null);
      expect(result).toBeNull();
    });

    it('should parse scalar type string', async () => {
      const result = await picoschema('string');
      expect(result).toEqual({ type: 'string' });
    });

    it('should parse scalar type number', async () => {
      const result = await picoschema('number');
      expect(result).toEqual({ type: 'number' });
    });

    it('should parse scalar type integer', async () => {
      const result = await picoschema('integer');
      expect(result).toEqual({ type: 'integer' });
    });

    it('should parse scalar type boolean', async () => {
      const result = await picoschema('boolean');
      expect(result).toEqual({ type: 'boolean' });
    });

    it('should parse scalar type null', async () => {
      const result = await picoschema('null');
      expect(result).toEqual({ type: 'null' });
    });

    it('should parse any type', async () => {
      const result = await picoschema('any');
      expect(result).toEqual({ type: 'any' });
    });
  });

  // Text after the first comma in a type string becomes the `description`.
  describe('descriptions', () => {
    it('should parse scalar type with description', async () => {
      const result = await picoschema('string, a string');
      expect(result).toEqual({ type: 'string', description: 'a string' });
    });

    it('should parse any type with description', async () => {
      const result = await picoschema('any, can be any type');
      expect(result).toEqual({ type: 'any', description: 'can be any type' });
    });
  });

  // Objects that already look like JSON Schema (have `type` and/or
  // `properties`) are passed through rather than re-expanded.
  describe('JSON Schema passthrough', () => {
    it('should pass through valid JSON Schema objects', async () => {
      const schema = {
        type: 'object',
        properties: { name: { type: 'string' } },
      };
      const result = await picoschema(schema);
      expect(result).toEqual(schema);
    });

    it('should add type: object when only properties is present', async () => {
      const schema = {
        properties: { name: { type: 'string' } },
      };
      const result = await picoschema(schema);
      expect(result).toEqual({
        type: 'object',
        properties: { name: { type: 'string' } },
      });
    });
  });

  // A plain map of property name -> type expands to a closed object schema:
  // every key is required and additionalProperties is false.
  describe('object shorthand', () => {
    it('should parse properties object shorthand', async () => {
      const result = await picoschema({ name: 'string' });
      expect(result).toEqual({
        type: 'object',
        properties: { name: { type: 'string' } },
        required: ['name'],
        additionalProperties: false,
      });
    });

    it('should parse multiple properties', async () => {
      const result = await picoschema({ name: 'string', age: 'integer' });
      expect(result).toHaveProperty('type', 'object');
      expect(result).toHaveProperty('properties.name');
      expect(result).toHaveProperty('properties.age');
    });
  });

  // `key(array)` wraps the value schema in { type: 'array', items: ... };
  // an optional `key?(array)` gets type ['array', 'null'] and is dropped
  // from `required`.
  describe('array type', () => {
    it('should parse array type', async () => {
      const result = await picoschema({ 'names(array)': 'string' });
      expect(result).toEqual({
        type: 'object',
        properties: {
          names: { type: 'array', items: { type: 'string' } },
        },
        required: ['names'],
        additionalProperties: false,
      });
    });

    it('should parse array type with description', async () => {
      const result = await picoschema({
        'items(array, list of items)': 'string',
      });
      expect(result?.properties?.items).toEqual({
        type: 'array',
        items: { type: 'string' },
        description: 'list of items',
      });
    });

    it('should parse optional array with description', async () => {
      const result = await picoschema({
        'items?(array, list of items)': 'string',
      });
      expect(result?.properties?.items).toEqual({
        type: ['array', 'null'],
        items: { type: 'string' },
        description: 'list of items',
      });
      // Optional-only objects carry no `required` list at all.
      expect(result?.required).toBeUndefined();
    });

    it('should parse nested arrays', async () => {
      const result = await picoschema({
        'items(array)': { 'props(array)': 'string' },
      });
      expect(result?.properties?.items?.type).toBe('array');
      expect(result?.properties?.items?.items?.type).toBe('object');
      expect(result?.properties?.items?.items?.properties?.props?.type).toBe(
        'array'
      );
    });
  });

  // `key(enum)` takes a literal list of allowed values; optional enums gain
  // `null` as an extra allowed value instead of a ['T','null'] type union.
  describe('enum type', () => {
    it('should parse enum type', async () => {
      const result = await picoschema({
        'status(enum)': ['active', 'inactive'],
      });
      expect(result).toEqual({
        type: 'object',
        properties: { status: { enum: ['active', 'inactive'] } },
        required: ['status'],
        additionalProperties: false,
      });
    });

    it('should parse enum type with description', async () => {
      const result = await picoschema({
        'status(enum, the status)': ['active', 'inactive'],
      });
      expect(result?.properties?.status).toEqual({
        enum: ['active', 'inactive'],
        description: 'the status',
      });
    });

    it('should parse optional enum with null', async () => {
      const result = await picoschema({
        'status?(enum)': ['active', 'inactive'],
      });
      expect(result?.properties?.status?.enum).toContain(null);
      expect(result?.required).toBeUndefined();
    });
  });

  // A trailing `?` on the key makes the property nullable and non-required.
  describe('optional properties', () => {
    it('should parse optional property', async () => {
      const result = await picoschema({ 'name?': 'string' });
      expect(result).toEqual({
        type: 'object',
        properties: { name: { type: ['string', 'null'] } },
        additionalProperties: false,
      });
    });
  });

  // The special key `(*)` sets the schema for arbitrary extra properties.
  describe('wildcard properties', () => {
    it('should parse wildcard property', async () => {
      const result = await picoschema({ '(*)': 'string' });
      expect(result).toEqual({
        type: 'object',
        properties: {},
        additionalProperties: { type: 'string' },
      });
    });
  });

  // `key(object)` recurses: the nested map expands with the same
  // required/additionalProperties rules as the top level.
  describe('nested objects', () => {
    it('should parse nested object', async () => {
      const result = await picoschema({
        'address(object)': { street: 'string' },
      });
      expect(result).toEqual({
        type: 'object',
        properties: {
          address: {
            type: 'object',
            properties: { street: { type: 'string' } },
            required: ['street'],
            additionalProperties: false,
          },
        },
        required: ['address'],
        additionalProperties: false,
      });
    });

    it('should parse nested object with description', async () => {
      const result = await picoschema({
        'address(object, the address)': { street: 'string' },
      });
      expect(result?.properties?.address?.description).toBe('the address');
    });
  });

  describe('description on type', () => {
    it('should parse description on property type', async () => {
      const result = await picoschema({ name: 'string, a name' });
      expect(result).toEqual({
        type: 'object',
        properties: {
          name: { type: 'string', description: 'a name' },
        },
        required: ['name'],
        additionalProperties: false,
      });
    });
  });

  // Non-string/non-object input and unknown type names must reject (the
  // `as any` cast bypasses the compile-time signature on purpose).
  describe('invalid inputs', () => {
    it('should throw on invalid schema type', async () => {
      await expect(picoschema(123 as any)).rejects.toThrow();
    });

    it('should throw on unsupported scalar type without resolver', async () => {
      await expect(picoschema('UndefinedType')).rejects.toThrow(
        /unsupported scalar type/i
      );
    });
  });
});

// Tests for the PicoschemaParser class, which extends the bare function with a
// pluggable SchemaResolver for looking up named (non-scalar) schemas.
describe('PicoschemaParser', () => {
  describe('schema resolution', () => {
    it('should resolve named schema', async () => {
      const resolver: SchemaResolver = (name) => {
        if (name === 'CustomType') return { type: 'integer' };
        return null;
      };

      const parser = new PicoschemaParser({ schemaResolver: resolver });
      const result = await parser.parse('CustomType');
      expect(result).toEqual({ type: 'integer' });
    });

    it('should resolve named schema with description', async () => {
      const resolver: SchemaResolver = (name) => {
        if (name === 'DescribedType') return { type: 'boolean' };
        return null;
      };

      const parser = new PicoschemaParser({ schemaResolver: resolver });
      // The `, description` suffix applies on top of the resolved schema.
      const result = await parser.parse('DescribedType, this is a description');
      expect(result).toEqual({
        type: 'boolean',
        description: 'this is a description',
      });
    });

    it('should throw when named schema not found', async () => {
      const resolver: SchemaResolver = () => null;
      const parser = new PicoschemaParser({ schemaResolver: resolver });

      await expect(parser.parse('NonExistentSchema')).rejects.toThrow(
        /could not find schema/i
      );
    });

    it('should resolve async schema resolver', async () => {
      // Resolvers may return a Promise; the parser must await it.
      const resolver: SchemaResolver = async (name) => {
        // Simulate async operation
        await new Promise((resolve) => setTimeout(resolve, 10));
        if (name === 'AsyncType') return { type: 'number' };
        return null;
      };

      const parser = new PicoschemaParser({ schemaResolver: resolver });
      const result = await parser.parse('AsyncType');
      expect(result).toEqual({ type: 'number' });
    });

    it('should resolve custom schema in property with description', async () => {
      const resolver: SchemaResolver = (name) => {
        if (name === 'CustomSchema') return { type: 'string' };
        return null;
      };

      const parser = new PicoschemaParser({ schemaResolver: resolver });
      const result = await parser.parse({
        field1: 'CustomSchema, a custom field',
      });

      expect(result?.properties?.field1).toEqual({
        type: 'string',
        description: 'a custom field',
      });
    });
  });

  // NOTE(review): `mustResolveSchema` is presumably private on the class —
  // hence the `as any` casts below. Testing private methods directly is
  // fragile; the behavior is already covered via `parse` above, so consider
  // whether these direct tests are worth keeping if the method is renamed.
  describe('mustResolveSchema', () => {
    it('should resolve successfully', async () => {
      const resolver: SchemaResolver = (name) => {
        if (name === 'MySchema')
          return { type: 'string', description: 'Resolved schema' };
        return null;
      };

      const parser = new PicoschemaParser({ schemaResolver: resolver });
      const result = await (parser as any).mustResolveSchema('MySchema');
      expect(result).toEqual({
        type: 'string',
        description: 'Resolved schema',
      });
    });

    it('should throw when schema not found', async () => {
      const resolver: SchemaResolver = () => null;
      const parser = new PicoschemaParser({ schemaResolver: resolver });

      await expect(
        (parser as any).mustResolveSchema('NonExistent')
      ).rejects.toThrow();
    });

    it('should throw when no resolver configured', async () => {
      const parser = new PicoschemaParser();

      // With no resolver, an unknown name is reported as an unsupported
      // scalar type rather than a lookup failure.
      await expect(
        (parser as any).mustResolveSchema('AnySchema')
      ).rejects.toThrow(/unsupported scalar type/i);
    });
  });
});