Add LoRA serving to mistral. #3555

Draft · wants to merge 1 commit into main
@@ -63,7 +63,7 @@
" - [mistralai/Mistral-7B-Instruct-v0.2](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2): improved instruction fine-tuned version of Mistral-7B-Instruct-v0.1 supporting 32k context length\n",
" - [mistralai/Mistral-7B-v0.3](https://huggingface.co/mistralai/Mistral-7B-v0.3): Mistral-7B-v0.2 with extended vocabulary of 32768 and supports function calling\n",
" - [mistralai/Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3): instruction fine-tuned version of the Mistral-7B-v0.3 generative text model\n",
" - [mistralai/Mistral-Nemo-Base-2407](https://huggingface.co/mistralai/Mistral-Nemo-Base-2407): pretrained generative text model of 12B parameters \n",
" - [mistralai/Mistral-Nemo-Base-2407](https://huggingface.co/mistralai/Mistral-Nemo-Base-2407): pretrained generative text model of 12B parameters\n",
" - [mistralai/Mistral-Nemo-Instruct-2407](https://huggingface.co/mistralai/Mistral-Nemo-Instruct-2407): instruct fine-tuned version of the Mistral-Nemo-Base-2407\n",
"\n",
"### Costs\n",
@@ -187,11 +187,11 @@
"\n",
"# @markdown Set the model to deploy.\n",
"\n",
"prebuilt_model_id = \"mistralai/Mistral-7B-Instruct-v0.3\" # @param [\"mistralai/Mistral-7B-v0.1\", \"mistralai/Mistral-7B-Instruct-v0.1\", \"mistralai/Mistral-7B-Instruct-v0.2\", \"mistralai/Mistral-7B-v0.3\", \"mistralai/Mistral-7B-Instruct-v0.3\", \"mistralai/Mistral-Nemo-Base-2407\", \"mistralai/Mistral-Nemo-Instruct-2407\"]\n",
"prebuilt_model_id = \"mistralai/Mistral-7B-Instruct-v0.3\" # @param [\"mistralai/Mistral-7B-v0.1\", \"mistralai/Mistral-7B-Instruct-v0.1\", \"mistralai/Mistral-7B-Instruct-v0.2\", \"mistralai/Mistral-7B-v0.3\", \"mistralai/Mistral-7B-Instruct-v0.3\", \"mistralai/Mistral-Nemo-Base-2407\", \"mistralai/Mistral-Nemo-Instruct-2407\"] {isTemplate: true}\n",
"model_id = f\"gs://vertex-model-garden-public-us/{prebuilt_model_id}\"\n",
"\n",
"# The pre-built serving docker image for vLLM.\n",
"VLLM_DOCKER_URI = \"us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20240721_0916_RC00\"\n",
"VLLM_DOCKER_URI = \"us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:us-docker.pkg.dev/vertex-ai/vertex-vision-model-garden-dockers/pytorch-vllm-serve:20240912_0916_RC00\"\n",
"\n",
"# Find Vertex AI prediction supported accelerators and regions in\n",
"# https://cloud.google.com/vertex-ai/docs/predictions/configure-compute.\n",
@@ -253,6 +253,7 @@
" enable_trust_remote_code: bool = False,\n",
" enforce_eager: bool = False,\n",
" enable_lora: bool = False,\n",
" max_lora_rank: int = 16,\n",
" max_loras: int = 1,\n",
" max_cpu_loras: int = 8,\n",
") -> Tuple[aiplatform.Model, aiplatform.Endpoint]:\n",
@@ -275,6 +276,7 @@
" f\"--gpu-memory-utilization={gpu_memory_utilization}\",\n",
" f\"--max-model-len={max_model_len}\",\n",
" f\"--dtype={dtype}\",\n",
" f\"--max-lora-rank={max_lora_rank}\",\n",
" f\"--max-loras={max_loras}\",\n",
" f\"--max-cpu-loras={max_cpu_loras}\",\n",
" \"--disable-log-stats\",\n",
@@ -338,6 +340,8 @@
" max_model_len=max_model_len,\n",
" gpu_memory_utilization=gpu_memory_utilization,\n",
" dtype=dtype,\n",
" enable_lora=True,\n",
" max_lora_rank=64,\n",
")"
]
},
@@ -358,23 +362,10 @@
"\n",
"# @markdown ```\n",
"# @markdown Human: What is a car?\n",
"# @markdown Assistant: A car, or a motor car, is a road-connected human-transportation system used to move people or goods from one place to another. The term also encompasses a wide range of vehicles, including motorboats, trains, and aircrafts. Cars typically have four wheels, a cabin for passengers, and an engine or motor. They have been around since the early 19th century and are now one of the most popular forms of transportation, used for daily commuting, shopping, and other purposes.\n",
"# @markdown Assistant: A car, or a motor car, is a road-connected human-transportation system used to move people or goods from one place to another.\n",
"# @markdown ```\n",
"# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.\n",
"\n",
"# Loads an existing endpoint instance using the endpoint name:\n",
"# - Using `endpoint_name = endpoint.name` allows us to get the\n",
"# endpoint name of the endpoint `endpoint` created in the cell\n",
"# above.\n",
"# - Alternatively, you can set `endpoint_name = \"1234567890123456789\"` to load\n",
"# an existing endpoint with the ID 1234567890123456789.\n",
"# You may uncomment the code below to load an existing endpoint.\n",
"\n",
"# endpoint_name = \"\" # @param {type:\"string\"}\n",
"# aip_endpoint_name = (\n",
"# f\"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_name}\"\n",
"# )\n",
"# endpoint = aiplatform.Endpoint(aip_endpoint_name)\n",
"# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.\n",
"\n",
"prompt = \"What is a car?\" # @param {type: \"string\"}\n",
"# @markdown If you encounter the issue like `ServiceUnavailable: 503 Took too long to respond when processing`, you can reduce the maximum number of output tokens, by lowering `max_tokens`.\n",
@@ -383,22 +374,27 @@
"top_p = 1.0 # @param {type:\"number\"}\n",
"top_k = 1 # @param {type:\"integer\"}\n",
"raw_response = False # @param {type:\"boolean\"}\n",
"# @markdown Optionally, you can apply LoRA weights to on a per-request basis. Set `lora_id` to be either a GCS URI or a HuggingFace repo containing the LoRA weight.\n",
"lora_id = \"\" # @param {type:\"string\", isTemplate: true}\n",
"\n",
"# Overrides parameters for inferences.\n",
"instances = [\n",
" {\n",
" \"prompt\": prompt,\n",
" \"max_tokens\": max_tokens,\n",
" \"temperature\": temperature,\n",
" \"top_p\": top_p,\n",
" \"top_k\": top_k,\n",
" \"raw_response\": raw_response,\n",
" },\n",
"]\n",
"instance = {\n",
" \"prompt\": prompt,\n",
" \"max_tokens\": max_tokens,\n",
" \"temperature\": temperature,\n",
" \"top_p\": top_p,\n",
" \"top_k\": top_k,\n",
" \"raw_response\": raw_response,\n",
"}\n",
"if lora_id:\n",
" instance[\"dynamic-lora\"] = lora_id\n",
"instances = [instance]\n",
"response = endpoints[\"vllm_gpu\"].predict(instances=instances)\n",
"\n",
"for prediction in response.predictions:\n",
" print(prediction)\n",
"# @markdown For example, for Mistral-7B deployments, you can set `lora_id` to\n",
"# @markdown [`uukuguy/Mistral-7B-OpenOrca-lora`](https://huggingface.co/uukuguy/Mistral-7B-OpenOrca-lora) to use the LoRA weights from the codealpaca repo.\n",
"\n",
"# Reference the following code for using the OpenAI vLLM server.\n",
"# import json\n",