Add v5e-4 and quantization #303

Draft: wants to merge 2 commits into base branch master.
5 changes: 4 additions & 1 deletion dags/inference/configs/maxtext_inference_gce_config.py
@@ -78,6 +78,7 @@ def get_maxtext_inference_nightly_config(
"ici_tensor_parallelism": f"{model_configs['ici_tensor_parallelism']}",
"per_device_batch_size": f"{model_configs['per_device_batch_size']}",
"weight_dtype": f"{model_configs['weight_dtype']}",
"quantization": f"{model_configs['quantization']}",
}

run_model_cmds = (
@@ -104,6 +105,7 @@ def get_maxtext_inference_nightly_config(
f"export SCAN_LAYERS={model_configs['scan_layers']}",
f"export WEIGHT_DTYPE={model_configs['weight_dtype']}",
f"export PER_DEVICE_BATCH_SIZE={model_configs['per_device_batch_size']}",
f"export QUANTIZATION={model_configs['quantization']}",
# Start JetStream MaxText server in the background
"""python MaxText/maxengine_server.py \
MaxText/configs/base.yml \
@@ -117,6 +119,7 @@ def get_maxtext_inference_nightly_config(
ici_tensor_parallelism=${ICI_TENSOR_PARALLELISM} \
scan_layers=${SCAN_LAYERS} \
weight_dtype=${WEIGHT_DTYPE} \
quantization=${QUANTIZATION} \
per_device_batch_size=${PER_DEVICE_BATCH_SIZE} > /dev/null 2>&1 &""",
"cd ..",
# Give server time to start
@@ -129,7 +132,7 @@ def get_maxtext_inference_nightly_config(
--dataset {model_configs['dataset']} \
--max-output-length {model_configs['max_output_length']} \
--request-rate {model_configs['request_rate']} \
--warmup-first true \
--warmup-mode sampled \
--save-result \
--additional-metadata-metrics-to-save ${{METADATA_DICT}} \
--save-request-outputs \
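For orientation, here is a minimal sketch, not part of the diff, of the pattern the added quantization field follows in this config: the value from model_configs is exported as an environment variable and then forwarded to the MaxText server as a config override. The names export_cmds and server_flags are illustrative; an empty quantization string leaves the model in unquantized bfloat16, while "int8" enables int8 quantization.

# Illustrative sketch of the export-then-forward pattern used above.
# An empty QUANTIZATION means the server runs with unquantized bfloat16 weights;
# "int8" turns on int8 quantization via the MaxText quantization config flag.
model_configs = {"weight_dtype": "bfloat16", "quantization": "int8"}

export_cmds = [
    f"export WEIGHT_DTYPE={model_configs['weight_dtype']}",
    f"export QUANTIZATION={model_configs['quantization']}",
]
server_flags = "weight_dtype=${WEIGHT_DTYPE} quantization=${QUANTIZATION}"

print("\n".join(export_cmds))
print(server_flags)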
276 changes: 189 additions & 87 deletions dags/inference/maxtext_inference.py
@@ -35,8 +35,9 @@
catchup=False,
) as dag:
test_name_prefix = "maxtext-inference"
test_models = {
"llama2-7b": {
test_models = [
# llama2-7b: 8-chip, bf16
("llama2-7b", {
"sleep_time": 120,
"tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.V5P, 8)],
"checkpoint": "gs://inference-benchmarks/models/llama2-7b/2024-04-25-14-01/param-only-decode-ckpt-maxtext/checkpoints/0/items",
@@ -54,8 +55,31 @@
"max_prefill_predict_length": 1024,
"max_target_length": 2048,
"max_output_length": 1024,
},
"llama2-13b": {
"quantization": "",
}),
# llama2-7b: 4-chip, bf16
("llama2-7b", {
"sleep_time": 120,
"tpu_version_cores": [(TpuVersion.V5E, 4)],
"checkpoint": "gs://inference-benchmarks/models/llama2-7b/2024-04-25-14-01/param-only-decode-ckpt-maxtext/checkpoints/0/items",
"model_mode": "base",
"maxtext_logs": "gs://inference-benchmarks/models/llama2-7b/2024-04-25-14-01/",
"scan_layers": "false",
"dataset": "openorca",
"weight_dtype": "bfloat16",
"tokenizer": "tokenizer.llama2",
"per_device_batch_sizes": [1, 2, 4, 8, 11, 12],
# (ici_fsdp_parallelism, ici_autoregressive_parallelism, ici_tensor_parallelism)
"ici_parallelisms": [(1, -1, 1), (1, 1, -1)],
"request_rate": 5,
"num_prompts": 1000,
"max_prefill_predict_length": 1024,
"max_target_length": 2048,
"max_output_length": 1024,
"quantization": "",
}),
# llama2-13b: 8-chip, bf16
("llama2-13b", {
"sleep_time": 120,
"tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.V5P, 8)],
"checkpoint": "gs://inference-benchmarks/models/llama2-13b/2024-04-25-14-01/param-only-decode-ckpt-maxtext/checkpoints/0/items",
@@ -73,11 +97,55 @@
"max_prefill_predict_length": 1024,
"max_target_length": 2048,
"max_output_length": 1024,
},
"llama2-70b": {
"quantization": "",
}),
# llama2-13b: 4-chip, bf16
("llama2-13b", {
"sleep_time": 120,
"tpu_version_cores": [(TpuVersion.V5E, 4)],
"checkpoint": "gs://inference-benchmarks/models/llama2-13b/2024-04-25-14-01/param-only-decode-ckpt-maxtext/checkpoints/0/items",
"model_mode": "base",
"maxtext_logs": "gs://inference-benchmarks/models/llama2-13b/2024-04-25-14-01/",
"scan_layers": "false",
"dataset": "openorca",
"weight_dtype": "bfloat16",
"tokenizer": "tokenizer.llama2",
"per_device_batch_sizes": [1, 2, 4, 5, 6],
# (ici_fsdp_parallelism, ici_autoregressive_parallelism, ici_tensor_parallelism)
"ici_parallelisms": [(1, -1, 1), (1, 1, -1)],
"request_rate": 5,
"num_prompts": 1000,
"max_prefill_predict_length": 1024,
"max_target_length": 2048,
"max_output_length": 1024,
"quantization": "",
}),
# llama2-70b: 8-chip, bf16
("llama2-70b", {
"sleep_time": 240,
"tpu_version_cores": [(TpuVersion.V5P, 8)],
"per_device_batch_sizes": [12, 16, 20, 24, 32, 48],
"checkpoint": "gs://inference-benchmarks/models/llama2-70b-chat/2024-05-08-23-16/param-only-decode-ckpt-maxtext/checkpoints/0/items",
"model_mode": "chat",
"maxtext_logs": "gs://inference-benchmarks/models/llama2-70b-chat/2024-05-08-23-16/",
"scan_layers": "false",
"dataset": "openorca",
"weight_dtype": "bfloat16",
"tokenizer": "tokenizer.llama2",
# (ici_fsdp_parallelism, ici_autoregressive_parallelism, ici_tensor_parallelism)
"ici_parallelisms": [(1, -1, 1), (1, 1, -1)],
"request_rate": 5,
"num_prompts": 1000,
"max_prefill_predict_length": 1024,
"max_target_length": 2048,
"max_output_length": 1024,
"quantization": "",
}),
# llama2-70b: 8-chip, int8
("llama2-70b", {
"sleep_time": 240,
"tpu_version_cores": [(TpuVersion.V5P, 8)],
"per_device_batch_sizes": [12, 16, 20, 24],
"per_device_batch_sizes": [12, 16, 20, 24, 32, 48],
"checkpoint": "gs://inference-benchmarks/models/llama2-70b-chat/2024-05-08-23-16/param-only-decode-ckpt-maxtext/checkpoints/0/items",
"model_mode": "chat",
"maxtext_logs": "gs://inference-benchmarks/models/llama2-70b-chat/2024-05-08-23-16/",
@@ -92,8 +160,10 @@
"max_prefill_predict_length": 1024,
"max_target_length": 2048,
"max_output_length": 1024,
},
"gemma-7b": {
"quantization": "int8",
}),
# gemma-7b: 8-chip, bf16
("gemma-7b", {
"sleep_time": 120,
"tpu_version_cores": [(TpuVersion.V5E, 8), (TpuVersion.V5P, 8)],
"checkpoint": "gs://inference-benchmarks/models/gemma-7b/2024-04-25-14-01/param-only-decode-ckpt-maxtext/checkpoints/0/items",
@@ -111,84 +181,116 @@
"max_prefill_predict_length": 1024,
"max_target_length": 2048,
"max_output_length": 1024,
},
}
"quantization": "",
}),
# gemma-7b: 4-chip, bf16
("gemma-7b", {
"sleep_time": 120,
"tpu_version_cores": [(TpuVersion.V5E, 4)],
"checkpoint": "gs://inference-benchmarks/models/gemma-7b/2024-04-25-14-01/param-only-decode-ckpt-maxtext/checkpoints/0/items",
"model_mode": "base",
"maxtext_logs": "gs://inference-benchmarks/models/gemma-7b/2024-04-25-14-01/",
"scan_layers": "false",
"dataset": "openorca",
"weight_dtype": "bfloat16",
"tokenizer": "tokenizer.gemma",
"per_device_batch_sizes": [1, 2, 4, 8, 11, 12],
# (ici_fsdp_parallelism, ici_autoregressive_parallelism, ici_tensor_parallelism)
"ici_parallelisms": [(1, -1, 1), (1, 1, -1)],
"request_rate": 5,
"num_prompts": 1000,
"max_prefill_predict_length": 1024,
"max_target_length": 2048,
"max_output_length": 1024,
"quantization": "",
}),
]
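As a brief aside on the structural change above, a minimal illustrative sketch (values are placeholders) of why the model table became a list of (name, config) tuples: the same model name can now appear more than once, for example llama2-70b swept both in bf16 and with int8 quantization, which the previous dict keyed by model name could not hold.

# Illustrative only: a dict keyed by model name would overwrite the first
# llama2-70b entry, while a list of (name, config) tuples keeps both sweeps.
test_models = [
    ("llama2-70b", {"quantization": ""}),      # bf16 sweep
    ("llama2-70b", {"quantization": "int8"}),  # int8 sweep
]

for model, cfg in test_models:
    quant_name = cfg["quantization"] if cfg["quantization"] else "bf16"
    print(f"{model}: {quant_name}")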

for model, sweep_model_configs in test_models:
for ici_parallelism in sweep_model_configs["ici_parallelisms"]:
for tpu_version, tpu_cores in sweep_model_configs["tpu_version_cores"]:
dags = []
for test_mode in [SetupMode.NIGHTLY, SetupMode.STABLE]:
for per_device_batch_size in sweep_model_configs["per_device_batch_sizes"]:

# Set per_device_batch_size to a single value, not a list
model_configs = {}
model_configs["model_name"] = model
model_configs["model_mode"] = sweep_model_configs["model_mode"]
model_configs["sleep_time"] = sweep_model_configs["sleep_time"]
model_configs["checkpoint"] = sweep_model_configs["checkpoint"]
model_configs["maxtext_logs"] = sweep_model_configs["maxtext_logs"]
model_configs["scan_layers"] = sweep_model_configs["scan_layers"]
model_configs["dataset"] = sweep_model_configs["dataset"]
model_configs["weight_dtype"] = sweep_model_configs["weight_dtype"]
model_configs["tokenizer"] = sweep_model_configs["tokenizer"]
model_configs["per_device_batch_size"] = per_device_batch_size
ici_fsdp = ici_parallelism[0]
ici_ar = ici_parallelism[1]
ici_tensor = ici_parallelism[2]
model_configs["ici_fsdp_parallelism"] = ici_fsdp
model_configs["ici_autoregressive_parallelism"] = ici_ar
model_configs["ici_tensor_parallelism"] = ici_tensor
model_configs["request_rate"] = sweep_model_configs["request_rate"]
model_configs["num_prompts"] = sweep_model_configs["num_prompts"]
model_configs["max_target_length"] = sweep_model_configs[
"max_target_length"
]
model_configs["max_prefill_predict_length"] = sweep_model_configs[
"max_prefill_predict_length"
]
model_configs["max_output_length"] = sweep_model_configs[
"max_output_length"
]
quant = sweep_model_configs["quantization"]
quant_name = quant if quant else "bf16"
model_configs["quantization"] = quant

for model, sweep_model_configs in test_models.items():
# tasks_per_model = []
for per_device_batch_size in sweep_model_configs["per_device_batch_sizes"]:
for ici_parallelism in sweep_model_configs["ici_parallelisms"]:
for tpu_version, tpu_cores in sweep_model_configs["tpu_version_cores"]:
# Set per_device_batch_size to a single value, not a list
model_configs = {}
model_configs["model_name"] = model
model_configs["model_mode"] = sweep_model_configs["model_mode"]
model_configs["sleep_time"] = sweep_model_configs["sleep_time"]
model_configs["checkpoint"] = sweep_model_configs["checkpoint"]
model_configs["maxtext_logs"] = sweep_model_configs["maxtext_logs"]
model_configs["scan_layers"] = sweep_model_configs["scan_layers"]
model_configs["dataset"] = sweep_model_configs["dataset"]
model_configs["weight_dtype"] = sweep_model_configs["weight_dtype"]
model_configs["tokenizer"] = sweep_model_configs["tokenizer"]
model_configs["per_device_batch_size"] = per_device_batch_size
ici_fsdp = ici_parallelism[0]
ici_ar = ici_parallelism[1]
ici_tensor = ici_parallelism[2]
model_configs["ici_fsdp_parallelism"] = ici_fsdp
model_configs["ici_autoregressive_parallelism"] = ici_ar
model_configs["ici_tensor_parallelism"] = ici_tensor
model_configs["request_rate"] = sweep_model_configs["request_rate"]
model_configs["num_prompts"] = sweep_model_configs["num_prompts"]
model_configs["max_target_length"] = sweep_model_configs[
"max_target_length"
]
model_configs["max_prefill_predict_length"] = sweep_model_configs[
"max_prefill_predict_length"
]
model_configs["max_output_length"] = sweep_model_configs[
"max_output_length"
]
if tpu_version == TpuVersion.V5E:
# v5e benchmarks
project_name = Project.TPU_PROD_ENV_AUTOMATED.value
zone = Zone.US_EAST1_C.value
network = V5_NETWORKS
subnetwork = V5E_SUBNETWORKS
runtime_version = RuntimeVersion.V2_ALPHA_TPUV5_LITE.value
elif tpu_version == TpuVersion.V5P:
zone = Zone.US_EAST5_A.value
runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value
project_name = Project.TPU_PROD_ENV_AUTOMATED.value
network = V5_NETWORKS
subnetwork = V5P_SUBNETWORKS

if tpu_version == TpuVersion.V5E:
# v5e benchmarks
project_name = Project.TPU_PROD_ENV_AUTOMATED.value
zone = Zone.US_EAST1_C.value
network = V5_NETWORKS
subnetwork = V5E_SUBNETWORKS
runtime_version = RuntimeVersion.V2_ALPHA_TPUV5_LITE.value
elif tpu_version == TpuVersion.V5P:
zone = Zone.US_EAST5_A.value
runtime_version = RuntimeVersion.V2_ALPHA_TPUV5.value
project_name = Project.TPU_PROD_ENV_AUTOMATED.value
network = V5_NETWORKS
subnetwork = V5P_SUBNETWORKS
# maxtext_stable_1slice = maxtext_inference_gce_config.get_maxtext_inference_nightly_config(
# tpu_version=tpu_version,
# tpu_cores=tpu_cores,
# tpu_zone=zone,
# runtime_version=runtime_version,
# project_name=project_name,
# time_out_in_min=60,
# is_tpu_reserved=True,
# test_name=f"{test_name_prefix}-stable-{model}-batch-{per_device_batch_size}-ici-fsdp{ici_fsdp}-ar{ici_ar}-tensor{ici_tensor}",
# test_mode=SetupMode.STABLE,
# network=network,
# subnetwork=subnetwork,
# model_configs=model_configs,
# )
maxtext_nightly_1slice = maxtext_inference_gce_config.get_maxtext_inference_nightly_config(
tpu_version=tpu_version,
tpu_cores=tpu_cores,
tpu_zone=zone,
runtime_version=runtime_version,
project_name=project_name,
time_out_in_min=60,
is_tpu_reserved=True,
test_name=f"{test_name_prefix}-{test_mode.value}-{model}-quant-{quant_name}-ici-fsdp{ici_fsdp}-ar{ici_ar}-tensor{ici_tensor}-batch-{per_device_batch_size}",
test_mode=test_mode,
network=network,
subnetwork=subnetwork,
model_configs=model_configs,
)
# dags.append(maxtext_stable_1slice)
dags.append(maxtext_nightly_1slice)
# maxtext_stable_1slice >> maxtext_nightly_1slice

maxtext_stable_1slice = maxtext_inference_gce_config.get_maxtext_inference_nightly_config(
tpu_version=tpu_version,
tpu_cores=tpu_cores,
tpu_zone=zone,
runtime_version=runtime_version,
project_name=project_name,
time_out_in_min=60,
is_tpu_reserved=True,
test_name=f"{test_name_prefix}-stable-{model}-per_device_batch_size-{per_device_batch_size}-ici-fsdp{ici_fsdp}-ar{ici_ar}-tensor{ici_tensor}",
test_mode=SetupMode.STABLE,
network=network,
subnetwork=subnetwork,
model_configs=model_configs,
)
maxtext_nightly_1slice = maxtext_inference_gce_config.get_maxtext_inference_nightly_config(
tpu_version=tpu_version,
tpu_cores=tpu_cores,
tpu_zone=zone,
runtime_version=runtime_version,
project_name=project_name,
time_out_in_min=60,
is_tpu_reserved=True,
test_name=f"{test_name_prefix}-nightly-{model}-per_device_batch_size-{per_device_batch_size}-ici-fsdp{ici_fsdp}-ar{ici_ar}-tensor{ici_tensor}",
test_mode=SetupMode.NIGHTLY,
network=network,
subnetwork=subnetwork,
model_configs=model_configs,
)
maxtext_stable_1slice >> maxtext_nightly_1slice
for i in range(1, len(dags)):
dags[i-1] >> dags[i]
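For readers less familiar with Airflow, here is a minimal self-contained sketch of the chaining idiom used in the final loop; it assumes Airflow 2.4+ and uses EmptyOperator with placeholder task ids in place of the benchmark task groups returned by get_maxtext_inference_nightly_config. Chaining dags[i-1] >> dags[i] makes each benchmark configuration wait for the previous one, so the sweep runs one configuration at a time rather than all in parallel.

# Sketch only: EmptyOperator stands in for the benchmark task groups.
import datetime

from airflow import DAG
from airflow.operators.empty import EmptyOperator

with DAG(
    dag_id="chaining_sketch",
    start_date=datetime.datetime(2024, 1, 1),
    schedule=None,
    catchup=False,
) as dag:
    dags = [EmptyOperator(task_id=f"benchmark_{i}") for i in range(3)]
    # Each task depends on the one before it, so tasks execute sequentially.
    for i in range(1, len(dags)):
        dags[i - 1] >> dags[i]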