Add speculative decoding part #3711
Status: Open

iamrk04 wants to merge 3 commits into `ayushmishra/add-nb-to-submit-rl-jobs` from `iamrk04/speculative_decoding`.
...undation-models/system/reinforcement-learning/data/draft_model/sharegpt_train_small.jsonl
500 changes: 500 additions & 0 deletions (large diff not rendered by default)
...tion-models/system/reinforcement-learning/environment/speculative-decoding-env/Dockerfile
10 changes: 10 additions & 0 deletions
```diff
@@ -0,0 +1,10 @@
+FROM lmsysorg/sglang:v0.5.2rc2-cu126
+ENV BASE_MODEL nvidia/Llama-3.1-8B-Instruct-FP8
+ENV DRAFT_MODEL lmsys/sglang-EAGLE3-LLaMA3.1-Instruct-8B
+ENV SGLANG_ARGS "--tp-size 1 --max-running-requests 32 --mem-fraction-static 0.8 --enable-torch-compile --speculative-algorithm EAGLE3 --speculative-num-steps 3 --speculative-eagle-topk 2 --speculative-num-draft-tokens 4 --dtype float16 --attention-backend fa3 --host 0.0.0.0 --port 30000"
+ENV SGL_HOST 0.0.0.0
+ENV SGL_PORT 30000
+ENV SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN 1
+
+EXPOSE 30000
+ENTRYPOINT python3 -m sglang.launch_server --model-path $BASE_MODEL --speculative-draft-model-path $DRAFT_MODEL $SGLANG_ARGS
```
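The Dockerfile pairs an expensive base (target) model with a cheap EAGLE3 draft model. The core accept/verify idea behind speculative decoding can be sketched as a toy greedy loop; this is illustrative only, not how sglang/EAGLE3 works internally, and both "models" below are made-up deterministic functions.

```python
# Toy sketch of greedy speculative decoding (illustrative only): a cheap
# "draft" model proposes k tokens; the expensive "target" model verifies
# them and keeps the longest agreeing prefix plus one correction token.

def speculative_step(prefix, draft_next, target_next, k):
    """Run one draft-then-verify round; return (new_prefix, accepted_count)."""
    proposal = []
    ctx = list(prefix)
    for _ in range(k):                 # draft proposes k tokens cheaply
        tok = draft_next(ctx)
        proposal.append(tok)
        ctx.append(tok)
    accepted = []
    ctx = list(prefix)
    for tok in proposal:               # target verifies each proposed token
        if target_next(ctx) == tok:
            accepted.append(tok)
            ctx.append(tok)
        else:
            break                      # reject the rest after first mismatch
    correction = target_next(ctx)      # target supplies one token itself
    return prefix + accepted + [correction], len(accepted)

# Hypothetical stand-in "models": the target continues an arithmetic
# sequence; the draft agrees except on multiples of 3.
def target_next(ctx):
    return ctx[-1] + 1

def draft_next(ctx):
    nxt = ctx[-1] + 1
    return nxt if nxt % 3 else -nxt

seq, accepted = speculative_step([0], draft_next, target_next, k=4)
# seq == [0, 1, 2, 3]; two draft tokens accepted, then the target corrected
```

One round thus yields several tokens for a single target verification pass, which is where the latency win comes from when the draft model agrees often.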
Binary file added (+55 KB): ...on-models/system/reinforcement-learning/images/metrics-base-target-spec-dec.png
@@ -123,15 +123,17 @@ | |
| "import matplotlib.pyplot as plt\n", | ||
| "from scripts.utils import setup_workspace\n", | ||
| "from scripts.dataset import prepare_finqa_dataset\n", | ||
| "from scripts.run import get_run_metrics\n", | ||
| "from scripts.run import get_run_output_assetid, get_run_metrics\n", | ||
| "from scripts.reinforcement_learning import run_rl_training_pipeline\n", | ||
| "from scripts.evaluation import run_evaluation_pipeline\n", | ||
| "from scripts.speculative_decoding import (\n", | ||
| " run_draft_model_pipeline,\n", | ||
| " prepare_combined_model_for_deployment,\n", | ||
| " deploy_speculative_decoding_endpoint,\n", | ||
| " deploy_base_model_endpoint,\n", | ||
| " run_evaluation_speculative_decoding,\n", | ||
| ")\n", | ||
| "from scripts.deployment import create_managed_deployment, test_deployment" | ||
| "from scripts.deployment import test_deployment" | ||
| ] | ||
| }, | ||
| { | ||
```diff
@@ -150,7 +152,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"<p>Prepare dataset for Finetuning. This would save train, test and valid dataset under data folder</p>"
+"<p>Prepare dataset for Fine-tuning. This would save train, test and valid dataset under data folder</p>"
 ]
 },
 {
```
```diff
@@ -208,8 +210,8 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# Run complete RL training pipeline: train model, register model\n",
-"grpo_job, status, grpo_registered_model = run_rl_training_pipeline(\n",
+"# Run complete RL training pipeline: verify datasets, register data, train model, register model\n",
+"grpo_job, status, registered_model = run_rl_training_pipeline(\n",
 " ml_client=ml_client,\n",
 " registry_ml_client=registry_ml_client,\n",
 " base_model_id=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\", # Huggingface ID ot the model which is to be RFT finetuned.\n",
```
```diff
@@ -262,7 +264,7 @@
 "outputs": [],
 "source": [
 "# Run complete RL training pipeline: verify datasets, register data, train model, register model\n",
-"rlpp_job, status, rlpp_registered_model = run_rl_training_pipeline(\n",
+"rlpp_job, status, registered_model = run_rl_training_pipeline(\n",
 " ml_client=ml_client,\n",
 " registry_ml_client=registry_ml_client,\n",
 " base_model_id=\"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\", # Huggingface ID ot the model which is to be RFT finetuned.\n",
```
```diff
@@ -325,32 +327,47 @@
 "metadata": {},
 "outputs": [],
 "source": [
-"# Function which invokes the model evaluation pipeline.\n",
-"eval_job, status = run_evaluation_pipeline(\n",
-" ml_client=ml_client,\n",
-" registry_ml_client=registry_ml_client,\n",
-" compute_cluster=\"k8s-a100-compute\",\n",
-" grpo_model_dir=grpo_registered_model.path, # Output from GPRO RL provided as data asset created from earlier step.\n",
-" rlpp_model_dir=rlpp_registered_model.path, # Output from Reinforce_plus_plus RL provided as data asset created from earlier step.\n",
-" validation_dataset_path=test_data_path, # Path to test dataset\n",
-" run_config={\n",
-" \"num_nodes\": 1, # Number of nodes to be used for evaluation run.\n",
-" \"number_of_gpu_to_use\": 8, # Number of GPUs in a node to be used for evaluation run.\n",
-" \"base_path_1_label\": \"GRPO\", # Label to identify GRPO model outputs.\n",
-" \"base_path_2_label\": \"RLPP\", # Label to identify RLPP model outputs.\n",
-" \"explore_pattern_1\": \"global_step_{checkpoint}/actor/lora_adapter/\",\n",
-" \"explore_pattern_2\": \"global_step_{checkpoint}/actor/lora_adapter/\",\n",
-" \"checkpoint_values_1\": \"12\",\n",
-" \"checkpoint_values_2\": \"12\",\n",
-" \"use_lora_adapters_1\": True,\n",
-" \"use_lora_adapters_2\": True,\n",
-" \"evaluate_base_model\": True, # Set to True to evaluate base model along with RL finetuned models.\n",
-" \"hf_model_id\": \"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\", # Huggingface ID of the base model\n",
-" \"max_prompt_length\": 8196,\n",
-" \"max_response_length\": 1024,\n",
-" \"dtype\": \"bfloat16\",\n",
-" \"tensor_parallel_size\": 4,\n",
-" }, # Configuration parameters for evaluation run.\n",
+"grpo_model_asset_id = get_run_output_assetid(\n",
+" ml_client, job_name=grpo_job.name, output_name=\"model_output\"\n",
+") # get model asset ID from the grpo training run\n",
+"rlpp_model_asset_id = get_run_output_assetid(\n",
+" ml_client, job_name=rlpp_job.name, output_name=\"model_output\"\n",
+") # get model asset ID from the rlpp training run"
 ]
 },
 {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+"eval_job, status = (\n",
+" run_evaluation_pipeline( # Function which invokes the model evaluation pipeline.\n",
+" ml_client=ml_client,\n",
+" registry_ml_client=registry_ml_client,\n",
+" compute_cluster=\"k8s-a100-compute\",\n",
+" grpo_model_dir=grpo_model_asset_id, # Output from GPRO RL provided as data asset created from earlier step.\n",
+" rlpp_model_dir=rlpp_model_asset_id, # Output from Reinforce_plus_plus RL provided as data asset created from earlier step.\n",
+" validation_dataset_path=test_data_path, # Path to test dataset\n",
+" run_config={ # Configuration to control base model, and also to point output location fo GRPO/RL++ runs.\n",
+" \"num_nodes\": 1,\n",
+" \"number_of_gpu_to_use\": 8,\n",
+" \"base_path_1_label\": \"GRPO\",\n",
+" \"base_path_2_label\": \"RLPP\",\n",
+" \"explore_pattern_1\": \"global_step_{checkpoint}/actor/huggingface/\",\n",
+" \"explore_pattern_2\": \"global_step_{checkpoint}/actor/lora_adapater/\",\n",
+" \"checkpoint_values_1\": \"300,280,260,240\",\n",
+" \"checkpoint_values_2\": \"264,260,240,220\",\n",
+" \"use_lora_adapters_1\": False,\n",
+" \"use_lora_adapters_2\": True,\n",
+" \"evaluate_base_model\": True,\n",
+" \"hf_model_id\": \"deepseek-ai/DeepSeek-R1-Distill-Qwen-7B\",\n",
+" \"max_prompt_length\": 8196,\n",
+" \"max_response_length\": 1024,\n",
+" \"dtype\": \"bfloat16\",\n",
+" \"tensor_parallel_size\": 4,\n",
+" },\n",
+" )\n",
+")"
 ]
 },
```

Member comment on this hunk: seems to be taking older version of PR => #3709. Could we take a pull here please?
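The updated `run_config` points evaluation at specific checkpoints through `explore_pattern_*` templates and comma-separated `checkpoint_values_*`. A minimal sketch of how such a pair might expand into concrete checkpoint paths follows; `expand_checkpoints` is a hypothetical helper, not the pipeline's actual code, which lives in the repo's scripts and may differ.

```python
# Hypothetical sketch: expand an explore_pattern / checkpoint_values pair
# (as seen in run_config) into concrete checkpoint directory paths.

def expand_checkpoints(explore_pattern: str, checkpoint_values: str) -> list[str]:
    """Substitute each comma-separated checkpoint value into the template."""
    return [
        explore_pattern.format(checkpoint=value.strip())
        for value in checkpoint_values.split(",")
    ]

paths = expand_checkpoints(
    "global_step_{checkpoint}/actor/huggingface/", "300,280,260,240"
)
# paths[0] == "global_step_300/actor/huggingface/"
```

This also shows why the two patterns differ per model: GRPO checkpoints here point at full `huggingface/` exports (`use_lora_adapters_1: False`) while RLPP points at LoRA adapter directories.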
```diff
@@ -498,7 +515,7 @@
 " num_epochs=1, # Number of train epochs to be run by draft trainer.\n",
 " monitor=False, # Set to True to wait for completion.\n",
 " base_model_mlflow_path=\"azureml://registries/azureml-meta/models/Meta-Llama-3-8B-Instruct/versions/9\",\n",
-" draft_train_data_path=\"./data_for_draft_model/train/sharegpt_train_small.jsonl\",\n",
+" draft_train_data_path=\"./data/draft_model/sharegpt_train_small.jsonl\",\n",
 ")"
 ]
 },
```
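The draft trainer reads `sharegpt_train_small.jsonl`. As a hedged sketch, one record in the commonly used ShareGPT-style layout looks like the following; the field names (`conversations`, `from`, `value`) are the conventional ShareGPT ones and are an assumption here, since the PR does not show the file's actual schema.

```python
import json

# Hedged sketch: one ShareGPT-style record, one JSON object per .jsonl line.
# The real sharegpt_train_small.jsonl may use a different schema.
record = {
    "conversations": [
        {"from": "human", "value": "What is speculative decoding?"},
        {"from": "gpt", "value": "A draft model proposes tokens that the target model verifies."},
    ]
}

line = json.dumps(record)            # serialize as a single .jsonl line
parsed = json.loads(line)            # reading it back is the inverse
first_turn = parsed["conversations"][0]["from"]
```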
```diff
@@ -591,8 +608,7 @@
 "endpoint_name = deploy_speculative_decoding_endpoint(\n",
 " ml_client=ml_client, # ML Client which specifies the workspace where endpoint gets deployed.\n",
 " combined_model=combined_model, # Reference from previous steps where combined model is created.\n",
-" instance_type=\"octagepu\", # Instance type Kubernetes Cluster\n",
-" compute_name=\"k8s-a100-compute\",\n",
+" instance_type=\"Standard_NC40ads_H100_v5\", # Instance type\n",
 ")"
 ]
 },
```
```diff
@@ -631,10 +647,9 @@
 "outputs": [],
 "source": [
 "# Deploy managed online endpoint with base model\n",
-"base_endpoint_name = create_managed_deployment( # Function to create endpoint for base model.\n",
+"base_endpoint_name = deploy_base_model_endpoint( # Function to create endpoint for base model.\n",
 " ml_client=ml_client, # ML Client which specifies the workspace where endpoint gets deployed.\n",
 " model_asset_id=\"meta-llama/Meta-Llama-3-8B-Instruct\", # Huggingface ID of the base model.\n",
-" instance_type=\"Standard_ND96amsr_A100_v4\", # Compute SKU on which base model will be deployed.\n",
+" instance_type=\"Standard_NC40ads_H100_v5\", # Compute SKU on which base model will be deployed.\n",
 ")"
 ]
 },
```
```diff
@@ -711,10 +726,12 @@
 "# Run evaluation job to compare base model and speculative decoding endpoints' performance\n",
 "evaluation_job = run_evaluation_speculative_decoding(\n",
 " ml_client=ml_client,\n",
+" registry_ml_client=registry_ml_client,\n",
 " base_endpoint_name=base_endpoint_name, # Base model endpoint from previous step.\n",
 " speculative_endpoint_name=endpoint_name, # Speculative endpoint from previous step.\n",
-" base_model=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in base endpoint, used for tokenization.\n",
-" speculative_model=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in speculative decoding endpoint, used for tokenization.\n",
+" base_model_hf_id=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in base endpoint, used for tokenization.\n",
+" speculative_model_hf_id=\"meta-llama/Meta-Llama-3-8B-Instruct\", # HuggingFace repo ID of the model used in speculative decoding endpoint, used for tokenization.\n",
+" compute_cluster=\"d13-v2\",\n",
 ")"
 ]
 },
```
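The evaluation job above compares the throughput of the base and speculative endpoints. The kind of comparison it reports can be sketched as a tokens-per-second speedup computed from per-request samples; the helper, metric shape, and sample numbers below are all made up for illustration and are not the pipeline's actual output format.

```python
# Hypothetical sketch: compare decode throughput of a base endpoint vs a
# speculative-decoding endpoint from (tokens_generated, seconds) samples.
# All numbers below are illustrative, not measured results.

def tokens_per_second(samples):
    """Aggregate throughput across (tokens, seconds) request samples."""
    total_tokens = sum(tokens for tokens, _ in samples)
    total_seconds = sum(seconds for _, seconds in samples)
    return total_tokens / total_seconds

base_samples = [(256, 8.0), (512, 16.0)]   # base endpoint: 768 tok / 24 s
spec_samples = [(256, 4.0), (512, 8.0)]    # speculative:   768 tok / 12 s
speedup = tokens_per_second(spec_samples) / tokens_per_second(base_samples)
# speedup == 2.0 for these illustrative samples
```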
```diff
@@ -735,7 +752,7 @@
 "cell_type": "markdown",
 "metadata": {},
 "source": [
-"<img src=\"metrics-base-target-spec-dec.png\" alt=\"Performance Metrics: Base Model vs Speculative Decoding\" style=\"max-width: 100%; height: auto;\">"
+"<img src=\"./images/metrics-base-target-spec-dec.png\" alt=\"Performance Metrics: Base Model vs Speculative Decoding\" style=\"max-width: 100%; height: auto;\">"
 ]
 }
 ],
```
Review comment: could we prepare this dataset via script?