Skip to content

Commit db1c58d

Browse files
Support for stream
1 parent c434700 commit db1c58d

File tree

6 files changed

+122
-23
lines changed

6 files changed

+122
-23
lines changed

Sources/PolyAI/Interfaces/Parameters/LLMMessageParameter.swift

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,31 @@
77

88
import Foundation
99

10+
// MARK: LLMMessageParameter
11+
1012
public protocol LLMMessageParameter {
1113

1214
var role: String { get }
1315
var content: String { get }
1416
}
17+
18+
// MARK: LLMMessage
19+
20+
public struct LLMMessage: LLMMessageParameter {
21+
22+
public var role: String
23+
public var content: String
24+
25+
public enum Role: String {
26+
case user
27+
case assistant
28+
}
29+
30+
public init(
31+
role: Role,
32+
content: String)
33+
{
34+
self.role = role.rawValue
35+
self.content = content
36+
}
37+
}

Sources/PolyAI/Interfaces/Parameters/LLMParameter.swift

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,15 @@
66
//
77

88
import Foundation
9+
import SwiftOpenAI
10+
import SwiftAnthropic
911

1012
public enum LLMParameter {
1113

12-
case openAI(model: String, messages: [LLMMessageParameter], maxTokens: Int? = nil, stream: Bool? = nil)
13-
case anthropic(model: String, messages: [LLMMessageParameter], maxTokens: Int, stream: Bool)
14+
case openAI(model: SwiftOpenAI.Model, messages: [LLMMessage], maxTokens: Int? = nil)
15+
case anthropic(model: SwiftAnthropic.Model, messages: [LLMMessage], maxTokens: Int)
1416

15-
var llm: String {
17+
var llmService: String {
1618
switch self {
1719
case .openAI: return "OpenAI"
1820
case .anthropic: return "Anthropic"

Sources/PolyAI/Interfaces/Response/LLMMessageResponse.swift

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,20 +36,12 @@ struct ChatUsageMetrics: UsageMetrics {
3636
public protocol ToolUsage {
3737
var toolId: String? { get }
3838
var toolName: String { get }
39-
var toolInput: [String: String]? { get } // Assuming tools might have inputs. Adjust as necessary.
39+
var toolInput: [String: String]? { get }
4040
}
4141

4242
// MARK: OpenAI
4343

4444
extension ChatCompletionObject: LLMMessageResponse {
45-
public var tools: [ToolUsage] {
46-
[]
47-
}
48-
49-
public var role: String {
50-
choices.first?.message.role ?? "unknown"
51-
}
52-
5345
public var createdAt: Int? {
5446
created
5547
}
@@ -58,9 +50,18 @@ extension ChatCompletionObject: LLMMessageResponse {
5850
choices.first?.message.content ?? ""
5951
}
6052

53+
6154
public var usageMetrics: UsageMetrics {
6255
ChatUsageMetrics(inputTokens: usage.promptTokens, outputTokens: usage.completionTokens, totalTokens: usage.totalTokens)
6356
}
57+
58+
public var tools: [ToolUsage] {
59+
[]
60+
}
61+
62+
public var role: String {
63+
choices.first?.message.role ?? "unknown"
64+
}
6465
}
6566

6667
// MARK: Anthropic
Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,35 @@
11
//
2-
// File.swift
2+
// LLMMessageStreamResponse.swift
33
//
44
//
55
// Created by James Rochabrun on 4/14/24.
66
//
77

88
import Foundation
9+
import SwiftAnthropic
10+
import SwiftOpenAI
11+
12+
// MARK: Interface
13+
14+
public protocol LLMMessageStreamResponse {
15+
var content: String? { get }
16+
}
17+
18+
// MARK: OpenAI
19+
20+
extension ChatCompletionChunkObject: LLMMessageStreamResponse {
21+
22+
public var content: String? {
23+
choices.first?.delta.content
24+
}
25+
}
26+
27+
// MARK: Anthropic
28+
29+
extension MessageStreamResponse: LLMMessageStreamResponse {
30+
31+
public var content: String? {
32+
delta?.text
33+
}
34+
}
35+

Sources/PolyAI/Service/DefaultPolyAIService.swift

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ struct DefaultPolyAIService: PolyAIService {
2424
switch configuration {
2525
case .openAI(let apiKey, let organizationID, let configuration, let decoder):
2626
openAIService = OpenAIServiceFactory.service(apiKey: apiKey, organizationID: organizationID, configuration: configuration, decoder: decoder)
27-
case .anthropic(let apiKey, let apiVersion, let configuration):
28-
anthropicService = AnthropicServiceFactory.service(apiKey: apiKey, apiVersion: apiVersion, configuration: configuration)
27+
case .anthropic(let apiKey, let configuration):
28+
anthropicService = AnthropicServiceFactory.service(apiKey: apiKey, configuration: configuration)
2929
}
3030
}
3131
}
@@ -37,20 +37,62 @@ struct DefaultPolyAIService: PolyAIService {
3737
async throws -> LLMMessageResponse
3838
{
3939
switch parameter {
40-
case .openAI(let model, let messages, let maxTokens, _):
40+
case .openAI(let model, let messages, let maxTokens):
4141
guard let openAIService else {
42-
throw PolyAIError.missingLLMConfiguration("You Must provide a valid configuration for the \(parameter.llm) API")
42+
throw PolyAIError.missingLLMConfiguration("You Must provide a valid configuration for the \(parameter.llmService) API")
4343
}
4444
let messageParams: [SwiftOpenAI.ChatCompletionParameters.Message] = messages.map { .init(role: .init(rawValue: $0.role) ?? .user, content: .text($0.content)) }
45-
let messageParameter = ChatCompletionParameters(messages: messageParams, model: .custom(model), maxTokens: maxTokens)
45+
let messageParameter = ChatCompletionParameters(messages: messageParams, model: model, maxTokens: maxTokens)
4646
return try await openAIService.startChat(parameters: messageParameter)
47-
case .anthropic(let model, let messages, let maxTokens, _):
47+
case .anthropic(let model, let messages, let maxTokens):
4848
guard let anthropicService else {
49-
throw PolyAIError.missingLLMConfiguration("You Must provide a valid configuration for the \(parameter.llm) API")
49+
throw PolyAIError.missingLLMConfiguration("You Must provide a valid configuration for the \(parameter.llmService) API")
5050
}
5151
let messageParams: [SwiftAnthropic.MessageParameter.Message] = messages.map { MessageParameter.Message(role: SwiftAnthropic.MessageParameter.Message.Role(rawValue: $0.role) ?? .user, content: .text($0.content)) }
52-
let messageParameter = MessageParameter(model: .other(model), messages: messageParams, maxTokens: maxTokens, stream: false)
52+
let messageParameter = MessageParameter(model: model, messages: messageParams, maxTokens: maxTokens, stream: false)
5353
return try await anthropicService.createMessage(messageParameter, beta: nil)
5454
}
5555
}
56+
57+
func streamMessage(
58+
_ parameter: LLMParameter)
59+
async throws -> AsyncThrowingStream<LLMMessageStreamResponse, Error>
60+
{
61+
switch parameter {
62+
case .openAI(let model, let messages, let maxTokens):
63+
guard let openAIService else {
64+
throw PolyAIError.missingLLMConfiguration("You Must provide a valid configuration for the \(parameter.llmService) API")
65+
}
66+
let messageParams: [SwiftOpenAI.ChatCompletionParameters.Message] = messages.map { .init(role: .init(rawValue: $0.role) ?? .user, content: .text($0.content)) }
67+
let messageParameter = ChatCompletionParameters(messages: messageParams, model: model, maxTokens: maxTokens)
68+
let stream = try await openAIService.startStreamedChat(parameters: messageParameter)
69+
return try mapToLLMMessageStreamResponse(stream: stream)
70+
case .anthropic(let model, let messages, let maxTokens):
71+
guard let anthropicService else {
72+
throw PolyAIError.missingLLMConfiguration("You Must provide a valid configuration for the \(parameter.llmService) API")
73+
}
74+
let messageParams: [SwiftAnthropic.MessageParameter.Message] = messages.map { MessageParameter.Message(role: SwiftAnthropic.MessageParameter.Message.Role(rawValue: $0.role) ?? .user, content: .text($0.content)) }
75+
let messageParameter = MessageParameter(model: model, messages: messageParams, maxTokens: maxTokens, stream: false)
76+
let stream = try await anthropicService.streamMessage(messageParameter, beta: nil)
77+
return try mapToLLMMessageStreamResponse(stream: stream)
78+
}
79+
}
80+
81+
private func mapToLLMMessageStreamResponse<T: LLMMessageStreamResponse>(stream: AsyncThrowingStream<T, Error>)
82+
throws -> AsyncThrowingStream<LLMMessageStreamResponse, Error>
83+
{
84+
let mappedStream = AsyncThrowingStream<LLMMessageStreamResponse, Error> { continuation in
85+
Task {
86+
do {
87+
for try await chunk in stream {
88+
continuation.yield(chunk)
89+
}
90+
continuation.finish()
91+
} catch {
92+
continuation.finish(throwing: error)
93+
}
94+
}
95+
}
96+
return mappedStream
97+
}
5698
}

Sources/PolyAI/Service/PolyAIService.swift

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import Foundation
99

1010
public enum LLMConfiguration {
11-
case openAI(apiKey: String, organizationID: String?, configuration: URLSessionConfiguration = .default, decoder: JSONDecoder = .init())
12-
case anthropic(apiKey: String, apiVersion: String, configuration: URLSessionConfiguration = .default)
11+
case openAI(apiKey: String, organizationID: String? = nil, configuration: URLSessionConfiguration = .default, decoder: JSONDecoder = .init())
12+
case anthropic(apiKey: String, configuration: URLSessionConfiguration = .default)
1313
}
1414

1515
public protocol PolyAIService {
@@ -21,4 +21,8 @@ public protocol PolyAIService {
2121
func createMessage(
2222
_ parameter: LLMParameter)
2323
async throws -> LLMMessageResponse
24+
25+
func streamMessage(
26+
_ parameter: LLMParameter)
27+
async throws -> AsyncThrowingStream<LLMMessageStreamResponse, Error>
2428
}

0 commit comments

Comments (0)