Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add CRaC Support #2

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 21 additions & 4 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<parent>
<groupId>org.springframework.boot</groupId>
<artifactId>spring-boot-starter-parent</artifactId>
<version>3.1.3</version>
<version>3.2.0</version>
<relativePath/> <!-- lookup parent from repository -->
</parent>
<groupId>com.apolloconfig.apollo.ai</groupId>
Expand All @@ -16,10 +16,13 @@
<description>a smart qa bot</description>
<properties>
<java.version>17</java.version>
<openai-gpt3-java.version>0.16.0</openai-gpt3-java.version>
<guava.version>32.1.2-jre</guava.version>
<openai-gpt3-java.version>0.18.2</openai-gpt3-java.version>
<guava.version>32.1.3-jre</guava.version>
<flexmark.version>0.64.8</flexmark.version>
<milvus.version>2.3.0</milvus.version>
<milvus.version>2.3.3</milvus.version>
<!-- There is a bug in 4.1.101.Final -->
<netty.codec.http2.version>4.1.100.Final</netty.codec.http2.version>
<crac.version>1.4.0</crac.version>
</properties>

<dependencyManagement>
Expand All @@ -44,6 +47,16 @@
<artifactId>milvus-sdk-java</artifactId>
<version>${milvus.version}</version>
</dependency>
<dependency>
<groupId>io.netty</groupId>
<artifactId>netty-codec-http2</artifactId>
<version>${netty.codec.http2.version}</version>
</dependency>
<dependency>
<groupId>org.crac</groupId>
<artifactId>crac</artifactId>
<version>${crac.version}</version>
</dependency>
</dependencies>
</dependencyManagement>

Expand Down Expand Up @@ -82,6 +95,10 @@
</exclusion>
</exclusions>
</dependency>
<dependency>
<groupId>org.crac</groupId>
<artifactId>crac</artifactId>
</dependency>

<dependency>
<groupId>org.springframework.boot</groupId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import io.milvus.client.MilvusServiceClient;
import io.milvus.param.ConnectParam;
import java.util.Map;
import org.crac.Context;
import org.crac.Resource;

public class MilvusClientFactory {
public class MilvusClientFactory implements Resource {

private static final MilvusClientFactory INSTANCE = new MilvusClientFactory();
private static final Map<String, MilvusServiceClient> clients = Maps.newConcurrentMap();
Expand Down Expand Up @@ -43,4 +45,14 @@ private MilvusServiceClient createClient(String host, int port) {
private MilvusServiceClient createCloudClient(String uri, String token) {
return new MilvusServiceClient(ConnectParam.newBuilder().withUri(uri).withToken(token).build());
}

@Override
public void beforeCheckpoint(Context<? extends Resource> context) throws Exception {
clients.clear();
}

@Override
public void afterRestore(Context<? extends Resource> context) throws Exception {

}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,30 @@
import java.util.List;
import java.util.Random;
import java.util.stream.Collectors;
import org.crac.Context;
import org.crac.Resource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Service;
import org.springframework.util.CollectionUtils;

@Profile("milvus")
@Service
class MilvusService implements VectorDBService {
class MilvusService implements VectorDBService, Resource {

private final MilvusServiceClient milvusServiceClient;
private static final Logger LOGGER = LoggerFactory.getLogger(MilvusService.class);

private MilvusServiceClient milvusServiceClient;
private final MilvusConfig milvusConfig;
private final List<Float> dummyEmbeddings = Lists.newArrayList();

public MilvusService(MilvusConfig milvusConfig) {
this.milvusConfig = milvusConfig;
this.init();
}

private void init() {
if (milvusConfig.isUseZillzCloud()) {
this.milvusServiceClient = MilvusClientFactory.getCloudClient(
milvusConfig.getZillizCloudUri(),
Expand Down Expand Up @@ -413,4 +423,16 @@ private void ensureFileCollection() {
);
}

@Override
public void beforeCheckpoint(Context<? extends Resource> context) throws Exception {
LOGGER.info("beforeCheckpoint");
this.milvusServiceClient = null;
}

@Override
public void afterRestore(Context<? extends Resource> context) throws Exception {
LOGGER.info("afterRestore");
this.init();
LOGGER.info("afterRestore done");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,25 @@
import com.theokanning.openai.embedding.EmbeddingRequest;
import io.reactivex.Flowable;
import java.util.List;
import org.crac.Context;
import org.crac.Resource;
import org.springframework.context.annotation.Profile;
import org.springframework.stereotype.Component;

@Profile("openai")
@Component
class OpenAiService implements AiService {
class OpenAiService implements AiService, Resource {

private static final String DEFAULT_MODEL = "gpt-3.5-turbo";
private static final String DEFAULT_EMBEDDING_MODEL = "text-embedding-ada-002";

private final com.theokanning.openai.service.OpenAiService service;
private com.theokanning.openai.service.OpenAiService service;

public OpenAiService() {
init();
}

private void init() {
service = OpenAiServiceFactory.getService(System.getenv("OPENAI_API_KEY"));
}

Expand Down Expand Up @@ -60,4 +66,14 @@ public List<Embedding> getEmbeddings(List<String> chunks) {

return service.createEmbeddings(embeddingRequest).getData();
}

@Override
public void beforeCheckpoint(Context<? extends Resource> context) throws Exception {
this.service = null;
}

@Override
public void afterRestore(Context<? extends Resource> context) throws Exception {
this.init();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,11 @@
import okhttp3.Authenticator;
import okhttp3.Credentials;
import okhttp3.OkHttpClient;
import org.crac.Context;
import org.crac.Resource;
import retrofit2.Retrofit;

public class OpenAiServiceFactory {
public class OpenAiServiceFactory implements Resource {

private static final Duration DEFAULT_TIMEOUT = Duration.ofSeconds(60);

Expand Down Expand Up @@ -85,6 +87,16 @@ private OkHttpClient client(String apiKey) {
.build();
}

@Override
public void beforeCheckpoint(Context<? extends Resource> context) throws Exception {
SERVICES.clear();
}

@Override
public void afterRestore(Context<? extends Resource> context) throws Exception {

}

private static class DelegatingSocketFactory extends SocketFactory {

private final SocketFactory delegate;
Expand Down
14 changes: 14 additions & 0 deletions src/main/scripts/startup.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,27 @@ SERVICE_NAME=qa-bot
LOG_DIR=/opt/logs
## Adjust server port if necessary
SERVER_PORT=${SERVER_PORT:=9090}
## Adjust crac files dir if necessary
CRAC_FILES_DIR=/opt/crac

## Create log directory if not existed because JDK 8+ won't do that
mkdir -p $LOG_DIR

mkdir -p $CRAC_FILES_DIR

## Adjust memory settings if necessary
#export JAVA_OPTS="-Xms2560m -Xmx2560m -Xss256k -XX:MetaspaceSize=128m -XX:MaxMetaspaceSize=384m -XX:NewSize=1536m -XX:MaxNewSize=1536m -XX:SurvivorRatio=8"

# Check for 'checkpoint' argument
if [ "$1" = "checkpoint" ]; then
export JAVA_OPTS="$JAVA_OPTS -Dspring.context.checkpoint=onRefresh -XX:CRaCCheckpointTo=$CRAC_FILES_DIR"
fi

# Check for 'restore' argument
if [ "$1" = "restore" ]; then
export JAVA_OPTS="$JAVA_OPTS -XX:CRaCRestoreFrom=$CRAC_FILES_DIR"
fi

## Only uncomment the following when you are using server jvm
export JAVA_OPTS="$JAVA_OPTS -server -XX:-ReduceInitialCardMarks"

Expand Down