Commit cb60d77

Config-based models supporting json schema
1 parent 7d70fc0 commit cb60d77

3 files changed: +44 -8 lines changed


openai-client/src/main/resources/openai-scala-client.conf

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 # Default Open AI Scala Client Config
 
 openai-scala-client {
-  apiKey = ${OPENAI_SCALA_CLIENT_API_KEY}
+  apiKey = ${?OPENAI_SCALA_CLIENT_API_KEY}
   orgId = ${?OPENAI_SCALA_CLIENT_ORG_ID}
 
   timeouts {
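The only change here is swapping the required substitution ${OPENAI_SCALA_CLIENT_API_KEY} for the optional form ${?OPENAI_SCALA_CLIENT_API_KEY}. With Typesafe Config, the optional substitution leaves the apiKey path undefined when the environment variable is not set, instead of failing the whole config resolution at load time. A minimal sketch of how that absence shows up to a consumer of the config (illustrative only, not part of the commit; path names are taken from the file above):

import com.typesafe.config.ConfigFactory

object ApiKeyPresenceCheck extends App {
  // With `${?...}`, the path is simply missing when
  // OPENAI_SCALA_CLIENT_API_KEY is not exported, so hasPath returns false.
  val config = ConfigFactory.load("openai-scala-client")

  val apiKey =
    if (config.hasPath("openai-scala-client.apiKey"))
      Some(config.getString("openai-scala-client.apiKey"))
    else
      None

  println(s"API key defined: ${apiKey.isDefined}")
}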

openai-client/src/main/scala/io/cequence/openaiscala/service/OpenAIServiceFactoryHelper.scala

Lines changed: 11 additions & 1 deletion
@@ -2,6 +2,7 @@ package io.cequence.openaiscala.service
 
 import akka.stream.Materializer
 import com.typesafe.config.{Config, ConfigFactory}
+import io.cequence.openaiscala.OpenAIScalaClientException
 import io.cequence.wsclient.ConfigImplicits._
 import io.cequence.wsclient.domain.WsRequestContext
 import io.cequence.wsclient.service.ws.Timeouts
@@ -50,8 +51,17 @@ trait OpenAIServiceFactoryHelper[F] extends OpenAIServiceConsts {
       pooledConnectionIdleTimeout = intTimeoutAux("pooledConnectionIdleTimeout")
     )
 
+    val apiKey = config
+      .optionalString(s"$configPrefix.apiKey")
+      .getOrElse(
+        throw new OpenAIScalaClientException(
+          s"API key is not defined in the config at '$configPrefix.apiKey'. " +
+            "Please set the OPENAI_SCALA_CLIENT_API_KEY environment variable or provide an API key explicitly."
+        )
+      )
+
     apply(
-      apiKey = config.getString(s"$configPrefix.apiKey"),
+      apiKey = apiKey,
       orgId = config.optionalString(s"$configPrefix.orgId"),
       timeouts = timeouts.toOption
     )
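Together with the optional substitution above, a missing API key now surfaces as an OpenAIScalaClientException with an actionable message instead of a generic ConfigException thrown during resolution. A standalone sketch of the same fail-fast pattern using only Typesafe Config (assuming optionalString from io.cequence.wsclient.ConfigImplicits behaves like the hasPath/getString check below):

import com.typesafe.config.{Config, ConfigFactory}

object RequiredKeyDemo extends App {
  // Turn an absent config path into a clear, actionable error message
  // rather than letting a ConfigException.Missing escape later.
  def requiredString(config: Config, path: String): String =
    if (config.hasPath(path)) config.getString(path)
    else
      throw new IllegalArgumentException(
        s"'$path' is not defined. Set the corresponding environment variable " +
          "or provide the value explicitly."
      )

  val config = ConfigFactory.load("openai-scala-client")
  println(requiredString(config, "openai-scala-client.apiKey"))
}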

openai-core/src/main/scala/io/cequence/openaiscala/service/OpenAIChatCompletionExtra.scala

Lines changed: 32 additions & 6 deletions
@@ -17,6 +17,7 @@ import play.api.libs.json.{Format, JsObject, JsValue, Json}
 
 import scala.concurrent.{ExecutionContext, Future}
 import com.fasterxml.jackson.core.JsonProcessingException
+import com.typesafe.config.Config
 import io.cequence.openaiscala.OpenAIScalaClientException
 import io.cequence.openaiscala.domain.JsonSchema.JsonSchemaOrMap
 import io.cequence.wsclient.JsonUtil
@@ -127,7 +128,8 @@ object OpenAIChatCompletionExtra extends OpenAIServiceConsts {
     maxRetries: Option[Int] = Some(defaultMaxRetries),
     retryOnAnyError: Boolean = false,
     taskNameForLogging: Option[String] = None,
-    jsonSchemaModels: Seq[String] = defaultModelsSupportingJsonSchema,
+    jsonSchemaModels: Seq[String] = Nil,
+    config: Option[Config] = None,
     enforceJsonSchemaMode: Boolean = false,
     parseJson: String => JsValue = defaultParseJsonOrRepair
   )(
@@ -144,6 +146,7 @@ object OpenAIChatCompletionExtra extends OpenAIServiceConsts {
         settings,
         taskNameForLoggingFinal,
         jsonSchemaModels,
+        config,
         enforceJsonSchemaMode
       )
     } else {
@@ -203,19 +206,42 @@ object OpenAIChatCompletionExtra extends OpenAIServiceConsts {
     }
   }
 
-  val defaultModelsSupportingJsonSchema = {
-    val config = loadDefaultConfig()
+  private lazy val defaultConfig = loadDefaultConfig()
+
+  private def getJsonSchemaModels(
+    jsonSchemaModels: Seq[String],
+    config: Option[Config]
+  ): Seq[String] = {
     import scala.collection.JavaConverters._
-    config.getStringList(s"$configPrefix.models-supporting-json-schema").asScala.toSeq
+    val cfg = config.getOrElse(defaultConfig)
+    val configPath = s"$configPrefix.models-supporting-json-schema"
+    val configModels = if (cfg.hasPath(configPath)) {
+      cfg.getStringList(configPath).asScala.toSeq
+    } else {
+      Nil
+    }
+
+    if (jsonSchemaModels.isEmpty) {
+      configModels
+    } else if (config.isDefined) {
+      // Both explicit models and custom config provided - merge and deduplicate
+      (jsonSchemaModels ++ configModels).distinct
+    } else {
+      // Only explicit models provided
+      jsonSchemaModels
+    }
   }
 
   def handleOutputJsonSchema(
     messages: Seq[BaseMessage],
     settings: CreateChatCompletionSettings,
     taskNameForLogging: String,
-    jsonSchemaModels: Seq[String] = defaultModelsSupportingJsonSchema,
+    jsonSchemaModels: Seq[String] = Nil,
+    config: Option[Config] = None,
     enforceJsonSchemaMode: Boolean = false
   ): (Seq[BaseMessage], CreateChatCompletionSettings) = {
+    val jsonSchemaModelsFinal = getJsonSchemaModels(jsonSchemaModels, config)
+
     val jsonSchemaDef = settings.jsonSchema.getOrElse(
       throw new IllegalArgumentException("JSON schema is not defined but expected.")
     )
@@ -226,7 +252,7 @@ object OpenAIChatCompletionExtra extends OpenAIServiceConsts {
     // to be more robust we also match models with a suffix
     if (
       enforceJsonSchemaMode ||
-      jsonSchemaModels.exists(model =>
+      jsonSchemaModelsFinal.exists(model =>
        settings.model.equals(model) || settings.model.endsWith("-" + model)
      )
    ) {
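The hard-coded defaultModelsSupportingJsonSchema value is replaced by getJsonSchemaModels, which resolves the model list from an explicit jsonSchemaModels argument, a caller-supplied Config, or the bundled default config, merging and deduplicating when both an explicit list and a custom config are given. A hypothetical caller-side sketch (the config path follows the configPrefix used above; the handleOutputJsonSchema call is left in outline since its message and settings arguments are elided here):

import com.typesafe.config.ConfigFactory

object CustomJsonSchemaModels extends App {
  // A custom config marking an extra (hypothetical) model as JSON-schema
  // capable, under the path read by getJsonSchemaModels.
  val customConfig = ConfigFactory.parseString(
    """
      |openai-scala-client {
      |  models-supporting-json-schema = ["my-gpt-4o-finetune", "gpt-4o-mini"]
      |}
      |""".stripMargin
  )

  // Explicit models and the custom config are merged and deduplicated:
  // handleOutputJsonSchema(
  //   messages,
  //   settings,
  //   taskNameForLogging = "extract-entities",
  //   jsonSchemaModels = Seq("gpt-4o-mini"),
  //   config = Some(customConfig)
  // )

  println(customConfig.getStringList("openai-scala-client.models-supporting-json-schema"))
}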
