Review comment (Member): could we prepare this dataset via script?
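A minimal sketch of what such a script could look like, assuming the small training set is simply a subsample of a full ShareGPT-style JSONL dump (both file paths and the sample count below are placeholders, not part of this PR):

import itertools
import json

SRC = "sharegpt_full.jsonl"  # assumed full ShareGPT-style dump (placeholder)
DST = "./data/draft_model/sharegpt_train_small.jsonl"
NUM_SAMPLES = 1000  # assumed size of the small subset (placeholder)

with open(SRC) as fin, open(DST, "w") as fout:
    for line in itertools.islice(fin, NUM_SAMPLES):
        json.loads(line)  # validate that each record parses before writing it
        fout.write(line)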

@@ -0,0 +1,10 @@
FROM lmsysorg/sglang:v0.5.2rc2-cu126
ENV BASE_MODEL nvidia/Llama-3.1-8B-Instruct-FP8
ENV DRAFT_MODEL lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B
ENV SGLANG_ARGS "--tp-size 1 --max-running-requests 32 --mem-fraction-static 0.8 --enable-torch-compile --speculative-algorithm EAGLE3 --speculative-num-steps 3 --speculative-eagle-topk 2 --speculative-num-draft-tokens 4 --dtype float16 --attention-backend fa3 --host 0.0.0.0 --port 30000"
ENV SGL_HOST 0.0.0.0
ENV SGL_PORT 30000
ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN 1

EXPOSE 30000
ENTRYPOINT python3 -m sglang.launch_server --model-path $BASE_MODEL --speculative-draft-model-path $DRAFT_MODEL $SGLANG_ARGS
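Once the container is running, the SGLang server exposes an OpenAI-compatible API on the port configured above. A minimal smoke-test sketch (host and port assume the image defaults; not part of this Dockerfile):

import requests

# Assumes the container is reachable on localhost:30000 as configured above.
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "nvidia/Llama-3.1-8B-Instruct-FP8",  # BASE_MODEL from the Dockerfile
        "messages": [{"role": "user", "content": "ping"}],
        "max_tokens": 8,
    },
    timeout=60,
)
print(resp.json())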
@@ -123,15 +123,17 @@
"import matplotlib.pyplot as plt\n",
"from scripts.utils import setup_workspace\n",
"from scripts.dataset import prepare_finqa_dataset\n",
"from scripts.run import get_run_metrics\n",
"from scripts.run import get_run_output_assetid, get_run_metrics\n",
"from scripts.reinforcement_learning import run_rl_training_pipeline\n",
"from scripts.evaluation import run_evaluation_pipeline\n",
"from scripts.speculative_decoding import (\n",
" run_draft_model_pipeline,\n",
" prepare_combined_model_for_deployment,\n",
" deploy_speculative_decoding_endpoint,\n",
" deploy_base_model_endpoint,\n",
" run_evaluation_speculative_decoding,\n",
")\n",
"from scripts.deployment import create_managed_deployment, test_deployment"
"from scripts.deployment import test_deployment"
]
},
{
@@ -150,7 +152,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"<p>Prepare dataset for Finetuning. This would save train, test and valid dataset under data folder</p>"
"<p>Prepare dataset for Fine-tuning. This would save train, test and valid dataset under data folder</p>"
]
},
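A minimal sketch of the call this cell describes, assuming prepare_finqa_dataset takes an output directory and returns the train/validation/test paths; the exact signature lives in scripts/dataset.py and may differ:

train_data_path, valid_data_path, test_data_path = prepare_finqa_dataset(
    output_dir="./data"  # assumed destination, per the "data folder" note above
)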
{
@@ -208,8 +210,8 @@
"metadata": {},
"outputs": [],
"source": [
"# Run complete RL training pipeline: train model, register model\n",
"grpo_job, status, grpo_registered_model = run_rl_training_pipeline(\n",
"# Run complete RL training pipeline: verify datasets, register data, train model, register model\n",
"grpo_job, status, registered_model = run_rl_training_pipeline(\n",
" ml_client=ml_client,\n",
" registry_ml_client=registry_ml_client,\n",
" base_model_id=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\", # Huggingface ID ot the model which is to be RFT finetuned.\n",
@@ -262,7 +264,7 @@
"outputs": [],
"source": [
"# Run complete RL training pipeline: verify datasets, register data, train model, register model\n",
"rlpp_job, status, rlpp_registered_model = run_rl_training_pipeline(\n",
"rlpp_job, status, registered_model = run_rl_training_pipeline(\n",
" ml_client=ml_client,\n",
" registry_ml_client=registry_ml_client,\n",
" base_model_id=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\", # Huggingface ID ot the model which is to be RFT finetuned.\n",
@@ -325,32 +327,47 @@
"metadata": {},
"outputs": [],
"source": [
"# Function which invokes the model evaluation pipeline.\n",
"eval_job, status = run_evaluation_pipeline(\n",
" ml_client=ml_client,\n",
" registry_ml_client=registry_ml_client,\n",
" compute_cluster=\"k8s-a100-compute\",\n",
" grpo_model_dir=grpo_registered_model.path, # Output from GPRO RL provided as data asset created from earlier step.\n",
" rlpp_model_dir=rlpp_registered_model.path, # Output from Reinforce_plus_plus RL provided as data asset created from earlier step.\n",
" validation_dataset_path=test_data_path, # Path to test dataset\n",
" run_config={\n",
" \"num_nodes\": 1, # Number of nodes to be used for evaluation run.\n",
" \"number_of_gpu_to_use\": 8, # Number of GPUs in a node to be used for evaluation run.\n",
" \"base_path_1_label\": \"GRPO\", # Label to identify GRPO model outputs.\n",
" \"base_path_2_label\": \"RLPP\", # Label to identify RLPP model outputs.\n",
" \"explore_pattern_1\": \"global_step_{checkpoint}/actor/lora_adapter/\",\n",
" \"explore_pattern_2\": \"global_step_{checkpoint}/actor/lora_adapter/\",\n",
" \"checkpoint_values_1\": \"12\",\n",
" \"checkpoint_values_2\": \"12\",\n",
" \"use_lora_adapters_1\": True,\n",
" \"use_lora_adapters_2\": True,\n",
" \"evaluate_base_model\": True, # Set to True to evaluate base model along with RL finetuned models.\n",
" \"hf_model_id\": \"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\", # Huggingface ID of the base model\n",
" \"max_prompt_length\": 8196,\n",
" \"max_response_length\": 1024,\n",
" \"dtype\": \"bfloat16\",\n",
" \"tensor_parallel_size\": 4,\n",
" }, # Configuration parameters for evaluation run.\n",
"grpo_model_asset_id = get_run_output_assetid(\n",
" ml_client, job_name=grpo_job.name, output_name=\"model_output\"\n",
[Review comment (Member): seems to be taking an older version of PR #3709. Could we take a pull here, please?]
") # get model asset ID from the grpo training run\n",
"rlpp_model_asset_id = get_run_output_assetid(\n",
" ml_client, job_name=rlpp_job.name, output_name=\"model_output\"\n",
") # get model asset ID from the rlpp training run"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"eval_job, status = (\n",
" run_evaluation_pipeline( # Function which invokes the model evaluation pipeline.\n",
" ml_client=ml_client,\n",
" registry_ml_client=registry_ml_client,\n",
" compute_cluster=\"k8s-a100-compute\",\n",
" grpo_model_dir=grpo_model_asset_id, # Output from GPRO RL provided as data asset created from earlier step.\n",
" rlpp_model_dir=rlpp_model_asset_id, # Output from Reinforce_plus_plus RL provided as data asset created from earlier step.\n",
" validation_dataset_path=test_data_path, # Path to test dataset\n",
" run_config={ # Configuration to control base model, and also to point output location fo GRPO/RL++ runs.\n",
" \"num_nodes\": 1,\n",
" \"number_of_gpu_to_use\": 8,\n",
" \"base_path_1_label\": \"GRPO\",\n",
" \"base_path_2_label\": \"RLPP\",\n",
" \"explore_pattern_1\": \"global_step_{checkpoint}/actor/huggingface/\",\n",
" \"explore_pattern_2\": \"global_step_{checkpoint}/actor/lora_adapater/\",\n",
" \"checkpoint_values_1\": \"300,280,260,240\",\n",
" \"checkpoint_values_2\": \"264,260,240,220\",\n",
" \"use_lora_adapters_1\": False,\n",
" \"use_lora_adapters_2\": True,\n",
" \"evaluate_base_model\": True,\n",
" \"hf_model_id\": \"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
" \"max_prompt_length\": 8196,\n",
" \"max_response_length\": 1024,\n",
" \"dtype\": \"bfloat16\",\n",
" \"tensor_parallel_size\": 4,\n",
" },\n",
" )\n",
")"
]
},
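The explore_pattern_{n} and checkpoint_values_{n} keys appear to work as pairs: each comma-separated checkpoint value is substituted into the {checkpoint} placeholder to locate a candidate checkpoint directory. A minimal sketch of the presumed expansion (an illustration, not the pipeline's actual code):

pattern = "global_step_{checkpoint}/actor/lora_adapter/"  # explore_pattern_2 above
for checkpoint in "264,260,240,220".split(","):  # checkpoint_values_2 above
    print(pattern.format(checkpoint=checkpoint))
# -> global_step_264/actor/lora_adapter/, global_step_260/actor/lora_adapter/, ...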
@@ -498,7 +515,7 @@
" num_epochs=1, # Number of train epochs to be run by draft trainer.\n",
" monitor=False, # Set to True to wait for completion.\n",
" base_model_mlflow_path=\"azureml://registries/azureml-meta/models/Meta-Llama-3-8B-Instruct/versions/9\",\n",
" draft_train_data_path=\"./data_for_draft_model/train/sharegpt_train_small.jsonl\",\n",
" draft_train_data_path=\"./data/draft_model/sharegpt_train_small.jsonl\",\n",
")"
]
},
@@ -591,8 +608,7 @@
"endpoint_name = deploy_speculative_decoding_endpoint(\n",
" ml_client=ml_client, # ML Client which specifies the workspace where endpoint gets deployed.\n",
" combined_model=combined_model, # Reference from previous steps where combined model is created.\n",
" instance_type=\"octagepu\", # Instance type Kubernetes Cluster\n",
" compute_name=\"k8s-a100-compute\",\n",
" instance_type=\"Standard_NC40ads_H100_v5\", # Instance type\n",
")"
]
},
@@ -631,10 +647,9 @@
"outputs": [],
"source": [
"# Deploy managed online endpoint with base model\n",
"base_endpoint_name = create_managed_deployment( # Function to create endpoint for base model.\n",
"base_endpoint_name = deploy_base_model_endpoint( # Function to create endpoint for base model.\n",
" ml_client=ml_client, # ML Client which specifies the workspace where endpoint gets deployed.\n",
" model_asset_id=\"meta-llama/Meta-Llama-3-8B-Instruct\", # Huggingface ID of the base model.\n",
" instance_type=\"Standard_ND96amsr_A100_v4\", # Compute SKU on which base model will be deployed.\n",
" instance_type=\"Standard_NC40ads_H100_v5\", # Compute SKU on which base model will be deployed.\n",
")"
]
},
@@ -711,10 +726,12 @@
"# Run evaluation job to compare base model and speculative decoding endpoints' performance\n",
"evaluation_job = run_evaluation_speculative_decoding(\n",
" ml_client=ml_client,\n",
" registry_ml_client=registry_ml_client,\n",
" base_endpoint_name=base_endpoint_name, # Base model endpoint from previous step.\n",
" speculative_endpoint_name=endpoint_name, # Speculative endpoint from previous step.\n",
" base_model=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in base endpoint, used for tokenization.\n",
" speculative_model=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in speculative decoding endpoint, used for tokenization.\n",
" base_model_hf_id=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in base endpoint, used for tokenization.\n",
" speculative_model_hf_id=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in speculative decoding endpoint, used for tokenization.\n",
" compute_cluster=\"d13-v2\",\n",
")"
]
},
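The notebook then visualizes the base-versus-speculative comparison (image below). A minimal sketch of pulling the raw numbers first, assuming get_run_metrics (imported earlier from scripts.run) returns a flat dict of metric names to values for a job; its exact signature and return shape live in scripts/run.py:

metrics = get_run_metrics(ml_client, job_name=evaluation_job.name)  # assumed signature
for name, value in metrics.items():
    print(f"{name}: {value}")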
@@ -735,7 +752,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"<img src=\"metrics-base-target-spec-dec.png\" alt=\"Performance Metrics: Base Model vs Speculative Decoding\" style=\"max-width: 100%; height: auto;\">"
"<img src=\"./images/metrics-base-target-spec-dec.png\" alt=\"Performance Metrics: Base Model vs Speculative Decoding\" style=\"max-width: 100%; height: auto;\">"
]
}
],
@@ -36,6 +36,7 @@ def create_managed_deployment(
ml_client: MLClient,
model_asset_id: str, # Asset ID of the model to deploy
instance_type: str, # Supported instance type for managed deployment
model_mount_path: Optional[str] = None,  # Optional path at which to mount the model inside the serving container
environment_asset_id: Optional[str] = None, # Asset ID of the serving engine to use
endpoint_name: Optional[str] = None,
endpoint_description: str = "Sample endpoint",
@@ -65,6 +66,7 @@
name=deployment_name,
endpoint_name=endpoint_name,
model=model_asset_id,
model_mount_path=model_mount_path,
instance_type=instance_type,
instance_count=1,
environment=environment_asset_id,
@@ -151,7 +153,10 @@ def test_deployment(ml_client, endpoint_name):
"""Run a test request against a deployed endpoint and print the result."""
print("Testing endpoint...")
# Retrieve endpoint URI and API key to authenticate test request
scoring_uri = ml_client.online_endpoints.get(endpoint_name).scoring_uri
scoring_uri = (
ml_client.online_endpoints.get(endpoint_name).scoring_uri.replace("/score", "/")
+ "v1/chat/completions"
)
if not scoring_uri:
raise ValueError("Scoring URI not found for endpoint.")
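# With the scoring URI rewritten to the OpenAI-compatible /v1/chat/completions
# route, the rest of test_deployment can send a standard chat-completions
# payload. A minimal sketch, assuming the endpoint uses key-based auth
# (get_keys/primary_key from azure-ai-ml); the actual request in this file may differ.
import requests

api_key = ml_client.online_endpoints.get_keys(endpoint_name).primary_key  # assumes key auth
response = requests.post(
    scoring_uri,
    headers={"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"},
    json={
        "messages": [{"role": "user", "content": "Say hello in one sentence."}],
        "max_tokens": 32,
    },
    timeout=60,
)
print(response.status_code, response.json())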
