
Error Running run_ray_serve_interleave with Llama3 8B #169

@ryanaoleary

Description


I'm receiving an error when attempting to run:

ray job submit -- python run_ray_serve_interleave.py --tpu_chips=4 --num_hosts=1 --size=8B --model_name=llama-3 --batch_size=8 --max_cache_length=2048 --tokenizer_path=$tokenizer_path --checkpoint_path=$output_ckpt_dir --quantize_weights=True --quantize_type="int8_per_channel" --quantize_kv_cache=True --sharding_config="default_shardings/llama.yaml"

on a single-host v4 TPU with a 2x2x1 topology. The error is:

ray.exceptions.ActorDiedError: The actor died because of an error raised in its creation task, ray::PyTorchRayWorker.__init__() (pid=5137, ip=10.168.0.16, actor_id=243ec964a2f41eae1707d84404000000, repr=<jetstream_pt.ray_worker.PyTorchRayWorker object at 0x7abe4074add0>)
  File "/home/ray/jetstream-pytorch/jetstream_pt/ray_worker.py", line 200, in __init__
    pt_model = model_exportable.Transformer(args, env)
  File "/home/ray/jetstream-pytorch/jetstream_pt/third_party/llama/model_exportable.py", line 192, in __init__
    self.tok_embeddings = Embedding(
  File "/home/ray/jetstream-pytorch/jetstream_pt/layers.py", line 57, in __init__
    table = torch.ones(
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/wrappers.py", line 252, in _fn
    result = fn(*args, **kwargs)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_refs/__init__.py", line 4774, in ones
    size = utils.extract_shape_from_varargs(size)
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 854, in extract_shape_from_varargs
    validate_shape(shape)  # type: ignore[arg-type]
  File "/home/ray/anaconda3/lib/python3.10/site-packages/torch/_prims_common/__init__.py", line 588, in validate_shape
    validate_dim_length(l)

This seems to be related to the following logic in ray_worker.py that creates the pt_model:

env_data.model_type = "llama-2-" + param_size
env_data.num_layers = args.n_layers
env = JetEngineEnvironment(env_data)
pt_model = model_exportable.Transformer(args, env)

Should the model_type for Llama models be hardcoded to "llama-2-", or should it be derived from the --model_name passed on the command line (here llama-3)?
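If the intent is for model_type to follow the model family given on the command line, a minimal sketch of the change could look like the following. This is only an illustration and is not verified against the rest of the codebase; it assumes a model_name value such as "llama-3" is available at this point in PyTorchRayWorker.__init__ alongside param_size:

# Hypothetical sketch, not the actual fix: build the prefix from model_name
# instead of hardcoding "llama-2-".
env_data.model_type = model_name + "-" + param_size  # e.g. "llama-3-8B" rather than "llama-2-8B"
env_data.num_layers = args.n_layers
env = JetEngineEnvironment(env_data)
pt_model = model_exportable.Transformer(args, env)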
