Skip to content

Commit

Permalink
Use predict_vllm function for prediction.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 672561096
  • Loading branch information
Minwoo Park authored and copybara-github committed Sep 9, 2024
1 parent edeac27 commit 2add43b
Show file tree
Hide file tree
Showing 3 changed files with 110 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -672,6 +672,33 @@
" return model, endpoint\n",
"\n",
"\n",
"def predict_vllm(\n",
"    prompt: str,\n",
"    max_tokens: int,\n",
"    temperature: float,\n",
"    top_p: float,\n",
"    top_k: int,\n",
"    raw_response: bool = False,\n",
"    lora_weight: str = \"\",\n",
"):\n",
"    \"\"\"Sends a prediction request to the deployed vLLM endpoint and prints predictions.\n",
"\n",
"    Args:\n",
"        prompt: Input text prompt.\n",
"        max_tokens: Maximum number of output tokens to generate.\n",
"        temperature: Sampling temperature.\n",
"        top_p: Nucleus sampling probability mass.\n",
"        top_k: Number of highest-probability tokens considered when sampling.\n",
"        raw_response: Passed through to the serving container; defaults to\n",
"            False so callers that do not set it still work.\n",
"        lora_weight: Optional GCS URI or HuggingFace repo with LoRA weights,\n",
"            applied dynamically when non-empty.\n",
"    \"\"\"\n",
"    # Parameters for inference.\n",
"    instance = {\n",
"        \"prompt\": prompt,\n",
"        \"max_tokens\": max_tokens,\n",
"        \"temperature\": temperature,\n",
"        \"top_p\": top_p,\n",
"        \"top_k\": top_k,\n",
"        \"raw_response\": raw_response,\n",
"    }\n",
"    if lora_weight:\n",
"        instance[\"dynamic-lora\"] = lora_weight\n",
"    instances = [instance]\n",
"    response = endpoints[\"vllm_gpu\"].predict(instances=instances)\n",
"\n",
"    for prediction in response.predictions:\n",
"        print(prediction)\n",
"\n",
"\n",
"models[\"vllm_gpu\"], endpoints[\"vllm_gpu\"] = deploy_model_vllm(\n",
" model_name=common_util.get_job_name_with_datetime(prefix=\"llama3_1-serve\"),\n",
" model_id=model_id,\n",
Expand Down Expand Up @@ -708,8 +735,6 @@
"# @markdown Assistant: A car, or a motor car, is a road-connected human-transportation system used to move people or goods from one place to another. The term also encompasses a wide range of vehicles, including motorboats, trains, and aircrafts. Cars typically have four wheels, a cabin for passengers, and an engine or motor. They have been around since the early 19th century and are now one of the most popular forms of transportation, used for daily commuting, shopping, and other purposes.\n",
"# @markdown ```\n",
"\n",
"# @markdown Optionally, you can apply LoRA weights to prediction. Set `lora_weight` to be either a GCS URI or a HuggingFace repo containing the LoRA weight.\n",
"\n",
"# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.\n",
"\n",
"prompt = \"What is a car?\" # @param {type: \"string\"}\n",
Expand All @@ -719,24 +744,18 @@
"top_p = 1.0 # @param {type:\"number\"}\n",
"top_k = 1 # @param {type:\"integer\"}\n",
"raw_response = False # @param {type:\"boolean\"}\n",
"lora_weight = \"\" # @param {type:\"string\", isTemplate: true}\n",
"\n",
"# Overrides parameters for inferences.\n",
"instance = {\n",
" \"prompt\": prompt,\n",
" \"max_tokens\": max_tokens,\n",
" \"temperature\": temperature,\n",
" \"top_p\": top_p,\n",
" \"top_k\": top_k,\n",
" \"raw_response\": raw_response,\n",
"}\n",
"if len(lora_weight) > 0:\n",
" instance[\"dynamic-lora\"] = lora_weight\n",
"instances = [instance]\n",
"response = endpoints[\"vllm_gpu\"].predict(instances=instances)\n",
"# @markdown Optionally, you can apply LoRA weights to prediction. Set `lora_weight` to be either a GCS URI or a HuggingFace repo containing the LoRA weight.\n",
"lora_weight = \"\" # @param {type:\"string\", isTemplate: true}\n",
"\n",
"for prediction in response.predictions:\n",
" print(prediction)\n",
"# Pass raw_response explicitly: predict_vllm declares it as a required\n",
"# parameter, so omitting it raises TypeError.\n",
"predict_vllm(\n",
"    prompt=prompt,\n",
"    max_tokens=max_tokens,\n",
"    temperature=temperature,\n",
"    top_p=top_p,\n",
"    top_k=top_k,\n",
"    raw_response=raw_response,\n",
"    lora_weight=lora_weight,\n",
")\n",
"\n",
"# @markdown Click \"Show Code\" to see more details."
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -600,6 +600,33 @@
"\n",
" return model, endpoint\n",
"\n",
"\n",
"def predict_vllm(\n",
"    prompt: str,\n",
"    max_tokens: int,\n",
"    temperature: float,\n",
"    top_p: float,\n",
"    top_k: int,\n",
"    raw_response: bool = False,\n",
"    lora_weight: str = \"\",\n",
"):\n",
"    \"\"\"Sends a prediction request to the deployed vLLM endpoint and prints predictions.\n",
"\n",
"    Args:\n",
"        prompt: Input text prompt.\n",
"        max_tokens: Maximum number of output tokens to generate.\n",
"        temperature: Sampling temperature.\n",
"        top_p: Nucleus sampling probability mass.\n",
"        top_k: Number of highest-probability tokens considered when sampling.\n",
"        raw_response: Passed through to the serving container; defaults to\n",
"            False so callers that do not set it still work.\n",
"        lora_weight: Optional GCS URI or HuggingFace repo with LoRA weights,\n",
"            applied dynamically when non-empty.\n",
"    \"\"\"\n",
"    # Parameters for inference.\n",
"    instance = {\n",
"        \"prompt\": prompt,\n",
"        \"max_tokens\": max_tokens,\n",
"        \"temperature\": temperature,\n",
"        \"top_p\": top_p,\n",
"        \"top_k\": top_k,\n",
"        \"raw_response\": raw_response,\n",
"    }\n",
"    if lora_weight:\n",
"        instance[\"dynamic-lora\"] = lora_weight\n",
"    instances = [instance]\n",
"    response = endpoints[\"vllm_gpu\"].predict(instances=instances)\n",
"\n",
"    for prediction in response.predictions:\n",
"        print(prediction)\n",
"\n",
"models[\"vllm_gpu\"], endpoints[\"vllm_gpu\"] = deploy_model_vllm(\n",
" model_name=common_util.get_job_name_with_datetime(prefix=\"llama3_1-vllm-serve\"),\n",
" model_id=merged_model_output_dir,\n",
Expand Down Expand Up @@ -633,21 +660,8 @@
"# @markdown Human: What is a car?\n",
"# @markdown Assistant: A car, or a motor car, is a road-connected human-transportation system used to move people or goods from one place to another. The term also encompasses a wide range of vehicles, including motorboats, trains, and aircrafts. Cars typically have four wheels, a cabin for passengers, and an engine or motor. They have been around since the early 19th century and are now one of the most popular forms of transportation, used for daily commuting, shopping, and other purposes.\n",
"# @markdown ```\n",
"# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.\n",
"\n",
"# Loads an existing endpoint instance using the endpoint name:\n",
"# - Using `endpoint_name = endpoint.name` allows us to get the\n",
"# endpoint name of the endpoint `endpoint` created in the cell\n",
"# above.\n",
"# - Alternatively, you can set `endpoint_name = \"1234567890123456789\"` to load\n",
"# an existing endpoint with the ID 1234567890123456789.\n",
"# You may uncomment the code below to load an existing endpoint.\n",
"\n",
"# endpoint_name = \"\" # @param {type:\"string\"}\n",
"# aip_endpoint_name = (\n",
"# f\"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_name}\"\n",
"# )\n",
"# endpoint = aiplatform.Endpoint(aip_endpoint_name)\n",
"# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.\n",
"\n",
"prompt = \"What is a car?\" # @param {type: \"string\"}\n",
"# @markdown If you encounter the issue like `ServiceUnavailable: 503 Took too long to respond when processing`, you can reduce the maximum number of output tokens, such as set `max_tokens` as 20.\n",
Expand All @@ -657,22 +671,13 @@
"top_k = 1 # @param {type:\"integer\"}\n",
"raw_response = False # @param {type:\"boolean\"}\n",
"\n",
"# Overrides parameters for inferences.\n",
"instances = [\n",
" {\n",
" \"prompt\": prompt,\n",
" \"max_tokens\": max_tokens,\n",
" \"temperature\": temperature,\n",
" \"top_p\": top_p,\n",
" \"top_k\": top_k,\n",
" \"raw_response\": raw_response,\n",
" },\n",
"]\n",
"response = endpoints[\"vllm_gpu\"].predict(instances=instances)\n",
"\n",
"for prediction in response.predictions:\n",
" print(prediction)\n",
"\n",
"# Pass raw_response explicitly: predict_vllm declares it as a required\n",
"# parameter, so omitting it raises TypeError.\n",
"predict_vllm(\n",
"    prompt=prompt,\n",
"    max_tokens=max_tokens,\n",
"    temperature=temperature,\n",
"    top_p=top_p,\n",
"    top_k=top_k,\n",
"    raw_response=raw_response,\n",
")\n",
"# @markdown Click \"Show Code\" to see more details."
]
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,34 @@
" )\n",
" print(\"endpoint_name:\", endpoint.name)\n",
"\n",
" return model, endpoint"
" return model, endpoint\n",
"\n",
"\n",
"def predict_vllm(\n",
"    prompt: str,\n",
"    max_tokens: int,\n",
"    temperature: float,\n",
"    top_p: float,\n",
"    top_k: int,\n",
"    raw_response: bool = False,\n",
"    lora_weight: str = \"\",\n",
"):\n",
"    \"\"\"Sends a prediction request to the deployed vLLM endpoint and prints predictions.\n",
"\n",
"    Args:\n",
"        prompt: Input text prompt.\n",
"        max_tokens: Maximum number of output tokens to generate.\n",
"        temperature: Sampling temperature.\n",
"        top_p: Nucleus sampling probability mass.\n",
"        top_k: Number of highest-probability tokens considered when sampling.\n",
"        raw_response: Passed through to the serving container; defaults to\n",
"            False so callers that do not set it still work.\n",
"        lora_weight: Optional GCS URI or HuggingFace repo with LoRA weights,\n",
"            applied dynamically when non-empty.\n",
"    \"\"\"\n",
"    # Parameters for inference.\n",
"    instance = {\n",
"        \"prompt\": prompt,\n",
"        \"max_tokens\": max_tokens,\n",
"        \"temperature\": temperature,\n",
"        \"top_p\": top_p,\n",
"        \"top_k\": top_k,\n",
"        \"raw_response\": raw_response,\n",
"    }\n",
"    if lora_weight:\n",
"        instance[\"dynamic-lora\"] = lora_weight\n",
"    instances = [instance]\n",
"    response = endpoints[\"vllm_gpu\"].predict(instances=instances)\n",
"\n",
"    for prediction in response.predictions:\n",
"        print(prediction)"
]
},
{
Expand Down Expand Up @@ -426,8 +453,6 @@
"# @markdown Assistant: A car, or a motor car, is a road-connected human-transportation system used to move people or goods from one place to another. The term also encompasses a wide range of vehicles, including motorboats, trains, and aircrafts. Cars typically have four wheels, a cabin for passengers, and an engine or motor. They have been around since the early 19th century and are now one of the most popular forms of transportation, used for daily commuting, shopping, and other purposes.\n",
"# @markdown ```\n",
"\n",
"# @markdown Optionally, you can apply LoRA weights to prediction. Set `lora_weight` to be either a GCS URI or a HuggingFace repo containing the LoRA weight.\n",
"\n",
"# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.\n",
"\n",
"prompt = \"What is a car?\" # @param {type: \"string\"}\n",
Expand All @@ -437,24 +462,18 @@
"top_p = 1.0 # @param {type:\"number\"}\n",
"top_k = 1 # @param {type:\"integer\"}\n",
"raw_response = False # @param {type:\"boolean\"}\n",
"\n",
"# @markdown Optionally, you can apply LoRA weights to prediction. Set `lora_weight` to be either a GCS URI or a HuggingFace repo containing the LoRA weight.\n",
"lora_weight = \"\" # @param {type:\"string\", isTemplate: true}\n",
"\n",
"# Overrides parameters for inferences.\n",
"instance = {\n",
" \"prompt\": prompt,\n",
" \"max_tokens\": max_tokens,\n",
" \"temperature\": temperature,\n",
" \"top_p\": top_p,\n",
" \"top_k\": top_k,\n",
" \"raw_response\": raw_response,\n",
"}\n",
"if len(lora_weight) > 0:\n",
" instance[\"dynamic-lora\"] = lora_weight\n",
"instances = [instance]\n",
"response = endpoints[\"vllm_gpu\"].predict(instances=instances)\n",
"\n",
"for prediction in response.predictions:\n",
" print(prediction)\n",
"# Pass raw_response explicitly: predict_vllm declares it as a required\n",
"# parameter, so omitting it raises TypeError.\n",
"predict_vllm(\n",
"    prompt=prompt,\n",
"    max_tokens=max_tokens,\n",
"    temperature=temperature,\n",
"    top_p=top_p,\n",
"    top_k=top_k,\n",
"    raw_response=raw_response,\n",
"    lora_weight=lora_weight,\n",
")\n",
"\n",
"# @markdown Click \"Show Code\" to see more details."
]
Expand Down

0 comments on commit 2add43b

Please sign in to comment.