Skip to content

Commit

Permalink
SDK regeneration (#33)
Browse files Browse the repository at this point in the history
Co-authored-by: fern-api <115122769+fern-api[bot]@users.noreply.github.com>
  • Loading branch information
fern-api[bot] authored Sep 20, 2024
1 parent 1cccf8c commit 05d207a
Show file tree
Hide file tree
Showing 85 changed files with 1,542 additions and 986 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public BaseType getBaseType() {
}

/**
* @return The fine-tuning strategy.
* @return Deprecated: The fine-tuning strategy.
*/
@JsonProperty("strategy")
public Optional<Strategy> getStrategy() {
Expand Down Expand Up @@ -165,7 +165,7 @@ public _FinalStage baseType(BaseType baseType) {
}

/**
* <p>The fine-tuning strategy.</p>
* <p>Deprecated: The fine-tuning strategy.</p>
* @return Reference to {@code this} so that method calls can be chained together.
*/
@java.lang.Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,12 @@ public final class Hyperparameters {

private final Optional<Double> learningRate;

private final Optional<Integer> loraAlpha;

private final Optional<Integer> loraRank;

private final Optional<LoraTargetModules> loraTargetModules;

private final Map<String, Object> additionalProperties;

private Hyperparameters(
Expand All @@ -38,12 +44,18 @@ private Hyperparameters(
Optional<Integer> trainBatchSize,
Optional<Integer> trainEpochs,
Optional<Double> learningRate,
Optional<Integer> loraAlpha,
Optional<Integer> loraRank,
Optional<LoraTargetModules> loraTargetModules,
Map<String, Object> additionalProperties) {
this.earlyStoppingPatience = earlyStoppingPatience;
this.earlyStoppingThreshold = earlyStoppingThreshold;
this.trainBatchSize = trainBatchSize;
this.trainEpochs = trainEpochs;
this.learningRate = learningRate;
this.loraAlpha = loraAlpha;
this.loraRank = loraRank;
this.loraTargetModules = loraTargetModules;
this.additionalProperties = additionalProperties;
}

Expand Down Expand Up @@ -89,6 +101,32 @@ public Optional<Double> getLearningRate() {
return learningRate;
}

/**
* @return Controls the scaling factor for LoRA updates. Higher values make the
* updates more impactful.
*/
@JsonProperty("lora_alpha")
public Optional<Integer> getLoraAlpha() {
return loraAlpha;
}

/**
* @return Specifies the rank for low-rank matrices. Lower ranks reduce parameters
* but may limit model flexibility.
*/
@JsonProperty("lora_rank")
public Optional<Integer> getLoraRank() {
return loraRank;
}

/**
* @return The combination of LoRA modules to target.
*/
@JsonProperty("lora_target_modules")
public Optional<LoraTargetModules> getLoraTargetModules() {
return loraTargetModules;
}

@java.lang.Override
public boolean equals(Object other) {
if (this == other) return true;
Expand All @@ -105,7 +143,10 @@ private boolean equalTo(Hyperparameters other) {
&& earlyStoppingThreshold.equals(other.earlyStoppingThreshold)
&& trainBatchSize.equals(other.trainBatchSize)
&& trainEpochs.equals(other.trainEpochs)
&& learningRate.equals(other.learningRate);
&& learningRate.equals(other.learningRate)
&& loraAlpha.equals(other.loraAlpha)
&& loraRank.equals(other.loraRank)
&& loraTargetModules.equals(other.loraTargetModules);
}

@java.lang.Override
Expand All @@ -115,7 +156,10 @@ public int hashCode() {
this.earlyStoppingThreshold,
this.trainBatchSize,
this.trainEpochs,
this.learningRate);
this.learningRate,
this.loraAlpha,
this.loraRank,
this.loraTargetModules);
}

@java.lang.Override
Expand All @@ -139,6 +183,12 @@ public static final class Builder {

private Optional<Double> learningRate = Optional.empty();

private Optional<Integer> loraAlpha = Optional.empty();

private Optional<Integer> loraRank = Optional.empty();

private Optional<LoraTargetModules> loraTargetModules = Optional.empty();

@JsonAnySetter
private Map<String, Object> additionalProperties = new HashMap<>();

Expand All @@ -150,6 +200,9 @@ public Builder from(Hyperparameters other) {
trainBatchSize(other.getTrainBatchSize());
trainEpochs(other.getTrainEpochs());
learningRate(other.getLearningRate());
loraAlpha(other.getLoraAlpha());
loraRank(other.getLoraRank());
loraTargetModules(other.getLoraTargetModules());
return this;
}

Expand Down Expand Up @@ -208,13 +261,49 @@ public Builder learningRate(Double learningRate) {
return this;
}

@JsonSetter(value = "lora_alpha", nulls = Nulls.SKIP)
public Builder loraAlpha(Optional<Integer> loraAlpha) {
this.loraAlpha = loraAlpha;
return this;
}

public Builder loraAlpha(Integer loraAlpha) {
this.loraAlpha = Optional.of(loraAlpha);
return this;
}

@JsonSetter(value = "lora_rank", nulls = Nulls.SKIP)
public Builder loraRank(Optional<Integer> loraRank) {
this.loraRank = loraRank;
return this;
}

public Builder loraRank(Integer loraRank) {
this.loraRank = Optional.of(loraRank);
return this;
}

@JsonSetter(value = "lora_target_modules", nulls = Nulls.SKIP)
public Builder loraTargetModules(Optional<LoraTargetModules> loraTargetModules) {
this.loraTargetModules = loraTargetModules;
return this;
}

public Builder loraTargetModules(LoraTargetModules loraTargetModules) {
this.loraTargetModules = Optional.of(loraTargetModules);
return this;
}

public Hyperparameters build() {
return new Hyperparameters(
earlyStoppingPatience,
earlyStoppingThreshold,
trainBatchSize,
trainEpochs,
learningRate,
loraAlpha,
loraRank,
loraTargetModules,
additionalProperties);
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/**
* This file was auto-generated by Fern from our API Definition.
*/
package com.cohere.api.resources.finetuning.finetuning.types;

import com.fasterxml.jackson.annotation.JsonValue;

public enum LoraTargetModules {
LORA_TARGET_MODULES_UNSPECIFIED("LORA_TARGET_MODULES_UNSPECIFIED"),

LORA_TARGET_MODULES_QV("LORA_TARGET_MODULES_QV"),

LORA_TARGET_MODULES_QKVO("LORA_TARGET_MODULES_QKVO"),

LORA_TARGET_MODULES_QKVO_FFN("LORA_TARGET_MODULES_QKVO_FFN");

private final String value;

LoraTargetModules(String value) {
this.value = value;
}

@JsonValue
@java.lang.Override
public String toString() {
return this.value;
}
}
107 changes: 98 additions & 9 deletions src/main/java/com/cohere/api/resources/v2/V2Client.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
import com.cohere.api.errors.CohereApiUnprocessableEntityError;
import com.cohere.api.resources.v2.requests.V2ChatRequest;
import com.cohere.api.resources.v2.requests.V2ChatStreamRequest;
import com.cohere.api.resources.v2.types.NonStreamedChatResponse2;
import com.cohere.api.resources.v2.types.StreamedChatResponse2;
import com.cohere.api.resources.v2.types.V2EmbedRequest;
import com.cohere.api.resources.v2.requests.V2EmbedRequest;
import com.cohere.api.resources.v2.requests.V2RerankRequest;
import com.cohere.api.resources.v2.types.V2RerankResponse;
import com.cohere.api.types.ChatResponse;
import com.cohere.api.types.ClientClosedRequestErrorBody;
import com.cohere.api.types.EmbedByTypeResponse;
import com.cohere.api.types.GatewayTimeoutErrorBody;
import com.cohere.api.types.NotImplementedErrorBody;
import com.cohere.api.types.StreamedChatResponseV2;
import com.cohere.api.types.TooManyRequestsErrorBody;
import com.cohere.api.types.UnprocessableEntityErrorBody;
import com.fasterxml.jackson.core.JsonProcessingException;
Expand All @@ -52,14 +54,14 @@ public V2Client(ClientOptions clientOptions) {
/**
* Generates a message from the model in response to a provided conversation. To learn how to use the Chat API with Streaming and RAG follow our Text Generation guides.
*/
public Iterable<StreamedChatResponse2> chatStream(V2ChatStreamRequest request) {
public Iterable<StreamedChatResponseV2> chatStream(V2ChatStreamRequest request) {
return chatStream(request, null);
}

/**
* Generates a message from the model in response to a provided conversation. To learn how to use the Chat API with Streaming and RAG follow our Text Generation guides.
*/
public Iterable<StreamedChatResponse2> chatStream(V2ChatStreamRequest request, RequestOptions requestOptions) {
public Iterable<StreamedChatResponseV2> chatStream(V2ChatStreamRequest request, RequestOptions requestOptions) {
HttpUrl httpUrl = HttpUrl.parse(this.clientOptions.environment().getUrl())
.newBuilder()
.addPathSegments("v2/chat")
Expand All @@ -84,7 +86,8 @@ public Iterable<StreamedChatResponse2> chatStream(V2ChatStreamRequest request, R
try (Response response = client.newCall(okhttpRequest).execute()) {
ResponseBody responseBody = response.body();
if (response.isSuccessful()) {
return new Stream<StreamedChatResponse2>(StreamedChatResponse2.class, responseBody.charStream(), "\n");
return new Stream<StreamedChatResponseV2>(
StreamedChatResponseV2.class, responseBody.charStream(), "\n");
}
String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
Expand Down Expand Up @@ -138,14 +141,14 @@ public Iterable<StreamedChatResponse2> chatStream(V2ChatStreamRequest request, R
/**
* Generates a message from the model in response to a provided conversation. To learn how to use the Chat API with Streaming and RAG follow our Text Generation guides.
*/
public NonStreamedChatResponse2 chat(V2ChatRequest request) {
public ChatResponse chat(V2ChatRequest request) {
return chat(request, null);
}

/**
* Generates a message from the model in response to a provided conversation. To learn how to use the Chat API with Streaming and RAG follow our Text Generation guides.
*/
public NonStreamedChatResponse2 chat(V2ChatRequest request, RequestOptions requestOptions) {
public ChatResponse chat(V2ChatRequest request, RequestOptions requestOptions) {
HttpUrl httpUrl = HttpUrl.parse(this.clientOptions.environment().getUrl())
.newBuilder()
.addPathSegments("v2/chat")
Expand All @@ -170,7 +173,7 @@ public NonStreamedChatResponse2 chat(V2ChatRequest request, RequestOptions reque
try (Response response = client.newCall(okhttpRequest).execute()) {
ResponseBody responseBody = response.body();
if (response.isSuccessful()) {
return ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), NonStreamedChatResponse2.class);
return ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), ChatResponse.class);
}
String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
Expand Down Expand Up @@ -310,4 +313,90 @@ public EmbedByTypeResponse embed(V2EmbedRequest request, RequestOptions requestO
throw new CohereApiError("Network error executing HTTP request", e);
}
}

/**
* This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
*/
public V2RerankResponse rerank(V2RerankRequest request) {
return rerank(request, null);
}

/**
* This endpoint takes in a query and a list of texts and produces an ordered array with each text assigned a relevance score.
*/
public V2RerankResponse rerank(V2RerankRequest request, RequestOptions requestOptions) {
HttpUrl httpUrl = HttpUrl.parse(this.clientOptions.environment().getUrl())
.newBuilder()
.addPathSegments("v2/rerank")
.build();
RequestBody body;
try {
body = RequestBody.create(
ObjectMappers.JSON_MAPPER.writeValueAsBytes(request), MediaTypes.APPLICATION_JSON);
} catch (JsonProcessingException e) {
throw new CohereApiError("Failed to serialize request", e);
}
Request okhttpRequest = new Request.Builder()
.url(httpUrl)
.method("POST", body)
.headers(Headers.of(clientOptions.headers(requestOptions)))
.addHeader("Content-Type", "application/json")
.build();
OkHttpClient client = clientOptions.httpClient();
if (requestOptions != null && requestOptions.getTimeout().isPresent()) {
client = clientOptions.httpClientWithTimeout(requestOptions);
}
try (Response response = client.newCall(okhttpRequest).execute()) {
ResponseBody responseBody = response.body();
if (response.isSuccessful()) {
return ObjectMappers.JSON_MAPPER.readValue(responseBody.string(), V2RerankResponse.class);
}
String responseBodyString = responseBody != null ? responseBody.string() : "{}";
try {
switch (response.code()) {
case 400:
throw new CohereApiBadRequestError(
ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class));
case 401:
throw new CohereApiUnauthorizedError(
ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class));
case 403:
throw new CohereApiForbiddenError(
ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class));
case 404:
throw new CohereApiNotFoundError(
ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class));
case 422:
throw new CohereApiUnprocessableEntityError(ObjectMappers.JSON_MAPPER.readValue(
responseBodyString, UnprocessableEntityErrorBody.class));
case 429:
throw new CohereApiTooManyRequestsError(ObjectMappers.JSON_MAPPER.readValue(
responseBodyString, TooManyRequestsErrorBody.class));
case 499:
throw new CohereApiClientClosedRequestError(ObjectMappers.JSON_MAPPER.readValue(
responseBodyString, ClientClosedRequestErrorBody.class));
case 500:
throw new CohereApiInternalServerError(
ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class));
case 501:
throw new CohereApiNotImplementedError(
ObjectMappers.JSON_MAPPER.readValue(responseBodyString, NotImplementedErrorBody.class));
case 503:
throw new CohereApiServiceUnavailableError(
ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class));
case 504:
throw new CohereApiGatewayTimeoutError(
ObjectMappers.JSON_MAPPER.readValue(responseBodyString, GatewayTimeoutErrorBody.class));
}
} catch (JsonProcessingException ignored) {
// unable to map error response, throwing generic error
}
throw new CohereApiApiError(
"Error with status code " + response.code(),
response.code(),
ObjectMappers.JSON_MAPPER.readValue(responseBodyString, Object.class));
} catch (IOException e) {
throw new CohereApiError("Network error executing HTTP request", e);
}
}
}
Loading

0 comments on commit 05d207a

Please sign in to comment.