
Commit 2add43b

Minwoo Park authored and copybara-github committed
Use predict_vllm function for prediction.
PiperOrigin-RevId: 672561096
1 parent edeac27 · commit 2add43b

File tree: 3 files changed (+110 −67 lines)


notebooks/community/model_garden/model_garden_pytorch_llama3_1_deployment.ipynb

Lines changed: 37 additions & 18 deletions
@@ -672,6 +672,33 @@
 "    return model, endpoint\n",
 "\n",
 "\n",
+"def predict_vllm(\n",
+"    prompt: str,\n",
+"    max_tokens: int,\n",
+"    temperature: float,\n",
+"    top_p: float,\n",
+"    top_k: int,\n",
+"    raw_response: bool,\n",
+"    lora_weight: str = \"\",\n",
+"):\n",
+"    # Parameters for inference.\n",
+"    instance = {\n",
+"        \"prompt\": prompt,\n",
+"        \"max_tokens\": max_tokens,\n",
+"        \"temperature\": temperature,\n",
+"        \"top_p\": top_p,\n",
+"        \"top_k\": top_k,\n",
+"        \"raw_response\": raw_response,\n",
+"    }\n",
+"    if lora_weight:\n",
+"        instance[\"dynamic-lora\"] = lora_weight\n",
+"    instances = [instance]\n",
+"    response = endpoints[\"vllm_gpu\"].predict(instances=instances)\n",
+"\n",
+"    for prediction in response.predictions:\n",
+"        print(prediction)\n",
+"\n",
+"\n",
 "models[\"vllm_gpu\"], endpoints[\"vllm_gpu\"] = deploy_model_vllm(\n",
 "    model_name=common_util.get_job_name_with_datetime(prefix=\"llama3_1-serve\"),\n",
 "    model_id=model_id,\n",
@@ -708,8 +735,6 @@
 "# @markdown Assistant: A car, or a motor car, is a road-connected human-transportation system used to move people or goods from one place to another. The term also encompasses a wide range of vehicles, including motorboats, trains, and aircrafts. Cars typically have four wheels, a cabin for passengers, and an engine or motor. They have been around since the early 19th century and are now one of the most popular forms of transportation, used for daily commuting, shopping, and other purposes.\n",
 "# @markdown ```\n",
 "\n",
-"# @markdown Optionally, you can apply LoRA weights to prediction. Set `lora_weight` to be either a GCS URI or a HuggingFace repo containing the LoRA weight.\n",
-"\n",
 "# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.\n",
 "\n",
 "prompt = \"What is a car?\" # @param {type: \"string\"}\n",
@@ -719,24 +744,18 @@
 "top_p = 1.0 # @param {type:\"number\"}\n",
 "top_k = 1 # @param {type:\"integer\"}\n",
 "raw_response = False # @param {type:\"boolean\"}\n",
-"lora_weight = \"\" # @param {type:\"string\", isTemplate: true}\n",
 "\n",
-"# Overrides parameters for inferences.\n",
-"instance = {\n",
-"    \"prompt\": prompt,\n",
-"    \"max_tokens\": max_tokens,\n",
-"    \"temperature\": temperature,\n",
-"    \"top_p\": top_p,\n",
-"    \"top_k\": top_k,\n",
-"    \"raw_response\": raw_response,\n",
-"}\n",
-"if len(lora_weight) > 0:\n",
-"    instance[\"dynamic-lora\"] = lora_weight\n",
-"instances = [instance]\n",
-"response = endpoints[\"vllm_gpu\"].predict(instances=instances)\n",
+"# @markdown Optionally, you can apply LoRA weights to prediction. Set `lora_weight` to be either a GCS URI or a HuggingFace repo containing the LoRA weight.\n",
+"lora_weight = \"\" # @param {type:\"string\", isTemplate: true}\n",
 "\n",
-"for prediction in response.predictions:\n",
-"    print(prediction)\n",
+"predict_vllm(\n",
+"    prompt=prompt,\n",
+"    max_tokens=max_tokens,\n",
+"    temperature=temperature,\n",
+"    top_p=top_p,\n",
+"    top_k=top_k,\n",
+"    lora_weight=lora_weight,\n",
+")\n",
 "\n",
 "# @markdown Click \"Show Code\" to see more details."
 ]
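
For readability, here is the new `predict_vllm` helper decoded from the escaped notebook JSON above. This is a sketch under the notebook's own setup: the module-level `endpoints` dict is assumed to already hold the Vertex AI endpoint returned by `deploy_model_vllm`.

# Decoded from the notebook source; endpoints["vllm_gpu"] is assumed to be a
# deployed google.cloud.aiplatform.Endpoint serving the vLLM container.
def predict_vllm(
    prompt: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
    top_k: int,
    raw_response: bool,
    lora_weight: str = "",
):
    # Parameters for inference.
    instance = {
        "prompt": prompt,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "top_k": top_k,
        "raw_response": raw_response,
    }
    # Attach a dynamic LoRA adapter only when a weight path is given.
    if lora_weight:
        instance["dynamic-lora"] = lora_weight
    instances = [instance]
    response = endpoints["vllm_gpu"].predict(instances=instances)

    for prediction in response.predictions:
        print(prediction)

Only `lora_weight` has a default; `raw_response` does not, so callers need to pass it explicitly.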

notebooks/community/model_garden/model_garden_pytorch_llama3_1_finetuning.ipynb

Lines changed: 35 additions & 30 deletions
@@ -600,6 +600,33 @@
 "\n",
 "    return model, endpoint\n",
 "\n",
+"\n",
+"def predict_vllm(\n",
+"    prompt: str,\n",
+"    max_tokens: int,\n",
+"    temperature: float,\n",
+"    top_p: float,\n",
+"    top_k: int,\n",
+"    raw_response: bool,\n",
+"    lora_weight: str = \"\",\n",
+"):\n",
+"    # Parameters for inference.\n",
+"    instance = {\n",
+"        \"prompt\": prompt,\n",
+"        \"max_tokens\": max_tokens,\n",
+"        \"temperature\": temperature,\n",
+"        \"top_p\": top_p,\n",
+"        \"top_k\": top_k,\n",
+"        \"raw_response\": raw_response,\n",
+"    }\n",
+"    if lora_weight:\n",
+"        instance[\"dynamic-lora\"] = lora_weight\n",
+"    instances = [instance]\n",
+"    response = endpoints[\"vllm_gpu\"].predict(instances=instances)\n",
+"\n",
+"    for prediction in response.predictions:\n",
+"        print(prediction)\n",
+"\n",
 "models[\"vllm_gpu\"], endpoints[\"vllm_gpu\"] = deploy_model_vllm(\n",
 "    model_name=common_util.get_job_name_with_datetime(prefix=\"llama3_1-vllm-serve\"),\n",
 "    model_id=merged_model_output_dir,\n",
@@ -633,21 +660,8 @@
 "# @markdown Human: What is a car?\n",
 "# @markdown Assistant: A car, or a motor car, is a road-connected human-transportation system used to move people or goods from one place to another. The term also encompasses a wide range of vehicles, including motorboats, trains, and aircrafts. Cars typically have four wheels, a cabin for passengers, and an engine or motor. They have been around since the early 19th century and are now one of the most popular forms of transportation, used for daily commuting, shopping, and other purposes.\n",
 "# @markdown ```\n",
-"# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.\n",
-"\n",
-"# Loads an existing endpoint instance using the endpoint name:\n",
-"# - Using `endpoint_name = endpoint.name` allows us to get the\n",
-"#   endpoint name of the endpoint `endpoint` created in the cell\n",
-"#   above.\n",
-"# - Alternatively, you can set `endpoint_name = \"1234567890123456789\"` to load\n",
-"#   an existing endpoint with the ID 1234567890123456789.\n",
-"# You may uncomment the code below to load an existing endpoint.\n",
 "\n",
-"# endpoint_name = \"\" # @param {type:\"string\"}\n",
-"# aip_endpoint_name = (\n",
-"#     f\"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_name}\"\n",
-"# )\n",
-"# endpoint = aiplatform.Endpoint(aip_endpoint_name)\n",
+"# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.\n",
 "\n",
 "prompt = \"What is a car?\" # @param {type: \"string\"}\n",
 "# @markdown If you encounter the issue like `ServiceUnavailable: 503 Took too long to respond when processing`, you can reduce the maximum number of output tokens, such as set `max_tokens` as 20.\n",
@@ -657,22 +671,13 @@
 "top_k = 1 # @param {type:\"integer\"}\n",
 "raw_response = False # @param {type:\"boolean\"}\n",
 "\n",
-"# Overrides parameters for inferences.\n",
-"instances = [\n",
-"    {\n",
-"        \"prompt\": prompt,\n",
-"        \"max_tokens\": max_tokens,\n",
-"        \"temperature\": temperature,\n",
-"        \"top_p\": top_p,\n",
-"        \"top_k\": top_k,\n",
-"        \"raw_response\": raw_response,\n",
-"    },\n",
-"]\n",
-"response = endpoints[\"vllm_gpu\"].predict(instances=instances)\n",
-"\n",
-"for prediction in response.predictions:\n",
-"    print(prediction)\n",
-"\n",
+"predict_vllm(\n",
+"    prompt=prompt,\n",
+"    max_tokens=max_tokens,\n",
+"    temperature=temperature,\n",
+"    top_p=top_p,\n",
+"    top_k=top_k,\n",
+")\n",
 "# @markdown Click \"Show Code\" to see more details."
 ]
 },
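
This notebook also drops the commented-out snippet for reattaching to an existing endpoint. Decoded, that removed snippet amounts to the following sketch; it assumes `PROJECT_ID` and `REGION` are set earlier in the notebook, and the endpoint ID is the example value from the removed comments.

from google.cloud import aiplatform

# Load an existing Vertex AI endpoint by ID instead of the one deployed above.
endpoint_name = "1234567890123456789"  # example ID from the removed comments
aip_endpoint_name = (
    f"projects/{PROJECT_ID}/locations/{REGION}/endpoints/{endpoint_name}"
)
endpoint = aiplatform.Endpoint(aip_endpoint_name)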

notebooks/community/model_garden/model_garden_pytorch_llama3_deployment.ipynb

Lines changed: 38 additions & 19 deletions
@@ -279,7 +279,34 @@
 "    )\n",
 "    print(\"endpoint_name:\", endpoint.name)\n",
 "\n",
-"    return model, endpoint"
+"    return model, endpoint\n",
+"\n",
+"\n",
+"def predict_vllm(\n",
+"    prompt: str,\n",
+"    max_tokens: int,\n",
+"    temperature: float,\n",
+"    top_p: float,\n",
+"    top_k: int,\n",
+"    raw_response: bool,\n",
+"    lora_weight: str = \"\",\n",
+"):\n",
+"    # Parameters for inference.\n",
+"    instance = {\n",
+"        \"prompt\": prompt,\n",
+"        \"max_tokens\": max_tokens,\n",
+"        \"temperature\": temperature,\n",
+"        \"top_p\": top_p,\n",
+"        \"top_k\": top_k,\n",
+"        \"raw_response\": raw_response,\n",
+"    }\n",
+"    if lora_weight:\n",
+"        instance[\"dynamic-lora\"] = lora_weight\n",
+"    instances = [instance]\n",
+"    response = endpoints[\"vllm_gpu\"].predict(instances=instances)\n",
+"\n",
+"    for prediction in response.predictions:\n",
+"        print(prediction)"
 ]
 },
 {
@@ -426,8 +453,6 @@
 "# @markdown Assistant: A car, or a motor car, is a road-connected human-transportation system used to move people or goods from one place to another. The term also encompasses a wide range of vehicles, including motorboats, trains, and aircrafts. Cars typically have four wheels, a cabin for passengers, and an engine or motor. They have been around since the early 19th century and are now one of the most popular forms of transportation, used for daily commuting, shopping, and other purposes.\n",
 "# @markdown ```\n",
 "\n",
-"# @markdown Optionally, you can apply LoRA weights to prediction. Set `lora_weight` to be either a GCS URI or a HuggingFace repo containing the LoRA weight.\n",
-"\n",
 "# @markdown Additionally, you can moderate the generated text with Vertex AI. See [Moderate text documentation](https://cloud.google.com/natural-language/docs/moderating-text) for more details.\n",
 "\n",
 "prompt = \"What is a car?\" # @param {type: \"string\"}\n",
@@ -437,24 +462,18 @@
 "top_p = 1.0 # @param {type:\"number\"}\n",
 "top_k = 1 # @param {type:\"integer\"}\n",
 "raw_response = False # @param {type:\"boolean\"}\n",
+"\n",
+"# @markdown Optionally, you can apply LoRA weights to prediction. Set `lora_weight` to be either a GCS URI or a HuggingFace repo containing the LoRA weight.\n",
 "lora_weight = \"\" # @param {type:\"string\", isTemplate: true}\n",
 "\n",
-"# Overrides parameters for inferences.\n",
-"instance = {\n",
-"    \"prompt\": prompt,\n",
-"    \"max_tokens\": max_tokens,\n",
-"    \"temperature\": temperature,\n",
-"    \"top_p\": top_p,\n",
-"    \"top_k\": top_k,\n",
-"    \"raw_response\": raw_response,\n",
-"}\n",
-"if len(lora_weight) > 0:\n",
-"    instance[\"dynamic-lora\"] = lora_weight\n",
-"instances = [instance]\n",
-"response = endpoints[\"vllm_gpu\"].predict(instances=instances)\n",
-"\n",
-"for prediction in response.predictions:\n",
-"    print(prediction)\n",
+"predict_vllm(\n",
+"    prompt=prompt,\n",
+"    max_tokens=max_tokens,\n",
+"    temperature=temperature,\n",
+"    top_p=top_p,\n",
+"    top_k=top_k,\n",
+"    lora_weight=lora_weight,\n",
+")\n",
 "\n",
 "# @markdown Click \"Show Code\" to see more details."
 ]
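
As in the other two notebooks, the inference cell now reduces to a single call. A hypothetical invocation with a dynamic LoRA adapter might look like the following sketch; the parameter values and the weight path are placeholders, and `raw_response` must be supplied since it has no default.

# Hypothetical usage, assuming the endpoint is deployed and predict_vllm is defined:
predict_vllm(
    prompt="What is a car?",
    max_tokens=50,
    temperature=1.0,
    top_p=1.0,
    top_k=1,
    raw_response=False,
    lora_weight="gs://your-bucket/llama3-lora",  # placeholder: GCS URI or HuggingFace repo
)

This sends a single instance to the vLLM endpoint with a "dynamic-lora" key added, matching the request the old inline cell used to build by hand.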
