18 changes: 18 additions & 0 deletions Dockerfile
@@ -46,5 +46,23 @@ RUN --mount=type=secret,id=HF_TOKEN,required=false \
python3 /src/download_model.py; \
fi

# Customisations for LoRA adapters stored in Wasabi (S3-compatible object storage)
RUN --mount=type=cache,target=/root/.cache/pip \
python3 -m pip install boto3

ARG WASABI_LORA_ADAPTER_PATH
Contributor
Can you explain what WASABI is?

Author
#169 is a better fix for the same issue, hence closing.


ENV WASABI_LORA_ADAPTER_PATH=$WASABI_LORA_ADAPTER_PATH
#ENV AWS_ACCESS_KEY_ID=$AWS_ACCESS_KEY_ID
#ENV AWS_SECRET_ACCESS_KEY=$AWS_SECRET_ACCESS_KEY

RUN mkdir /model_adapter
# Download the adapter using an AWS-style credentials file mounted as a build secret
RUN --mount=type=secret,id=credentials,target=/root/.aws/credentials \
python3 /src/download_lora_adapter.py

# Alternative: use credentials injected as ENV vars (see the commented ENV lines above)
# RUN python3 /src/download_lora_adapter.py

# Start the handler
CMD ["python3", "/src/handler.py"]
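A note on building this image (the PR does not spell it out): the credentials secret mounted above is expected to be an AWS-style credentials file with a [default] profile, since download_lora_adapter.py reads profile_name="default". With BuildKit enabled, the secret and the adapter path could be supplied roughly as
docker build --secret id=credentials,src=$HOME/.aws/credentials --build-arg WASABI_LORA_ADAPTER_PATH=s3://<bucket>/<prefix> .
where the s3://<bucket>/<prefix> value is a placeholder, not something this PR defines.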
54 changes: 54 additions & 0 deletions src/download_lora_adapter.py
@@ -0,0 +1,54 @@
import os
from pathlib import Path
from urllib.parse import urlparse

import boto3

# Connect using the "default" Wasabi profile from the ~/.aws/credentials file
# (mounted into the image as a build secret in the Dockerfile).
session = boto3.Session(profile_name="default")
credentials = session.get_credentials()

aws_access_key_id = credentials.access_key
aws_secret_access_key = credentials.secret_key

# aws_access_key_id = os.getenv("WASABI_ACCESS_KEY")
# aws_secret_access_key = os.getenv("WASABI_SECRET_ACCESS_KEY")
adapter_path = os.getenv("WASABI_LORA_ADAPTER_PATH")

print("WASABI_ACCESS_KEY: ", aws_access_key_id)
print("WASABI_SECRET_ACCESS_KEY: ", aws_secret_access_key)
print("WASABI_LORA_ADAPTER_PATH: ", adapter_path)

# The endpoint URL is region-specific and is fixed when the bucket is created.
ENDPOINT_URL = 'https://s3.eu-west-1.wasabisys.com'


def download_s3_folder(s3_uri, local_dir=None):
    """
    Download every object under an S3 prefix ("folder").

    Args:
        s3_uri: s3://bucket/prefix URI of the top-level folder to download
        local_dir: relative or absolute local directory to download into;
            if None, objects are written relative to the current directory
    """
s3 = boto3.resource("s3",
endpoint_url=ENDPOINT_URL,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key)
    parsed = urlparse(s3_uri)
    bucket = s3.Bucket(parsed.hostname)
    s3_path = parsed.path.lstrip('/')
if local_dir is not None:
local_dir = Path(local_dir)
    for obj in bucket.objects.filter(Prefix=s3_path):
        target = Path(obj.key) if local_dir is None else local_dir / Path(obj.key).relative_to(s3_path)
        target.parent.mkdir(parents=True, exist_ok=True)
        if obj.key.endswith('/'):
            # skip zero-byte "directory" placeholder keys
            continue
        bucket.download_file(obj.key, str(target))

if __name__ == "__main__":
    # Download into the directory created by the Dockerfile (RUN mkdir /model_adapter);
    # without an explicit local_dir the files land relative to the build working directory.
    download_s3_folder(adapter_path, local_dir="/model_adapter")
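To make the key-to-path mapping concrete, a worked example with a made-up bucket and prefix: given WASABI_LORA_ADAPTER_PATH=s3://my-lora-bucket/adapters/llama3, an object adapters/llama3/adapter_config.json is written to /model_adapter/adapter_config.json, because each key is taken relative to the prefix before being joined to local_dir.

download_s3_folder("s3://my-lora-bucket/adapters/llama3", local_dir="/model_adapter")  # hypothetical URI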
4 changes: 3 additions & 1 deletion src/engine.py
@@ -143,10 +143,12 @@ async def _initialize_engines(self):
engine_client=self.llm,
model_config=self.model_config,
base_model_paths=self.base_model_paths,
lora_modules=None,
lora_modules=lora_modules,
prompt_adapters=None,
)

await self.serving_models.init_static_loras()

self.chat_engine = OpenAIServingChat(
engine_client=self.llm,
model_config=self.model_config,
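The lora_modules value now forwarded to OpenAIServingModels is built elsewhere in engine.py and is not visible in this hunk; init_static_loras() then registers those adapters at startup. Purely as a hypothetical sketch of what such a list could look like for the adapter baked into the image (the import path of LoRAModulePath, the helper function, and the adapter name are all assumptions, not code from this PR):

import os

from vllm.entrypoints.openai.serving_models import LoRAModulePath  # location varies by vLLM release


def build_lora_modules(adapter_dir="/model_adapter", name="wasabi-lora"):
    # Return a one-element list if an adapter was downloaded into the image, else None.
    if not os.path.isdir(adapter_dir) or not os.listdir(adapter_dir):
        return None
    return [LoRAModulePath(name=name, path=adapter_dir)]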
9 changes: 7 additions & 2 deletions src/handler.py
@@ -1,12 +1,17 @@
import os
import logging

import runpod
from utils import JobInput

from engine import vLLMEngine, OpenAIvLLMEngine
from utils import JobInput

log = logging.getLogger(__name__)

vllm_engine = vLLMEngine()
OpenAIvLLMEngine = OpenAIvLLMEngine(vllm_engine)

async def handler(job):
    log.info("handler(job=%s)", job)
job_input = JobInput(job["input"])
engine = OpenAIvLLMEngine if job_input.openai_route else vllm_engine
results_generator = engine.generate(job_input)
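One caveat on the new logging: handler.py only obtains a module-level logger, so the log.info(...) call produces output only if logging is configured somewhere in the process. A minimal sketch, assuming nothing else (runpod included) has already configured the root logger:

import logging

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(name)s %(levelname)s %(message)s",
)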