PD heterogeneous TP #77

Open: wants to merge 19 commits into base: disagg_pd_dev
198 changes: 144 additions & 54 deletions vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py
@@ -52,6 +52,8 @@ class NixlAgentMetadata(
agent_metadata: bytes
kv_caches_base_addr: list[int]
num_blocks: int
tp_size: int
block_len: int


@dataclass
@@ -223,8 +225,8 @@ def __init__(self, engine_id: str):

# Agent.
self.nixl_wrapper = NixlWrapper(str(uuid.uuid4()), None)
# Map of engine_id -> agent_name.
self._remote_agents: dict[str, str] = {}
# Map of engine_id -> {rank0: agent_name0, rank1: agent_name1..}.
self._remote_agents: dict[str, dict[int, str]] = defaultdict(dict)

# Metadata.
self.engine_id = engine_id
@@ -235,20 +237,23 @@ def __init__(self, engine_id: str):
# KV Caches and nixl tracking data.
self.kv_caches: dict[str, torch.Tensor] = {}

# Map of engine_id -> kv_caches_base_addr
self.kv_caches_base_addr: dict[str, list[int]] = {}
# Map of engine_id -> kv_caches_base_addr. In the TP case, each local
# rank still only pulls from a single remote TP worker.
self.kv_caches_base_addr: dict[str, list[int]] = dict()

# Number of NIXL regions. Currently one region per cache
# (so 1 per layer for MLA, otherwise 2 per layer)
self.num_regions = 0

# nixl_prepped_dlist_handle (int).
self.src_xfer_side_handle: int = 0
# nixl_prepped_dlist_handle. Different dst TP sizes require preparing
# the xfer layout differently.
self.src_xfer_side_handle: dict[int, int] = dict()
# Map of engine_id -> nixl_prepped_dlist_handle (int)].
self.dst_xfer_side_handles: dict[str, int] = {}
self.dst_xfer_side_handles: dict[str, int] = dict()

# Map of engine_id -> num_blocks.
self.dst_num_blocks: dict[str, int] = {}
# Map of engine_id -> num_blocks. Remote TP ranks will have the same
# number of blocks.
self.dst_num_blocks: dict[str, int] = dict()
self._registered_descs: list[Any] = []

# In progress transfers.
@@ -267,6 +272,8 @@ def __init__(self, engine_id: str):
# Background thread for establishing new connections.
self._nixl_handshake_listener_t: Optional[threading.Thread] = None

self._tp_size = {self.engine_id: self.world_size}

@staticmethod
def _nixl_handshake_listener(metadata: NixlAgentMetadata,
ready_event: threading.Event, rank: int):
@@ -290,6 +297,7 @@ def _nixl_handshake_listener(metadata: NixlAgentMetadata,
# NOTE(rob): we need each rank to have a unique port. This
# hack keeps us moving. We will switch when moving to etcd
# or to where we have a single ZMQ socket in the scheduler.
# TODO get rank port util
port = envs.VLLM_NIXL_SIDE_CHANNEL_PORT + rank
path = f"tcp://{host}:{port}"
logger.debug("Starting listening on path: %s", path)
@@ -306,12 +314,12 @@ def _nixl_handshake(self, host: str, port: int):
"""Do a NIXL handshake with a remote instance."""

start_time = time.perf_counter()

# NOTE(rob): we need each rank to have a unique port. This is
# a hack to keep us moving. We will switch when moving to etcd
# or where we have a single ZMQ socket in the scheduler.
path = f"tcp://{host}:{port + self.rank}"
logger.debug("Querying metadata on path: %s", path)
with zmq_ctx(zmq.REQ, path) as sock:

def handshake(sock, rank: int) -> NixlAgentMetadata:
# Send query for the request.
sock.send(GET_META_MSG)
metadata_bytes = sock.recv()
@@ -320,13 +328,32 @@ def _nixl_handshake(self, host: str, port: int):
got_metadata_time = time.perf_counter()

# Register Remote agent.
self.add_remote_agent(metadata)
self.add_remote_agent(metadata, rank)
setup_agent_time = time.perf_counter()

logger.debug("NIXL handshake: get metadata took: %s",
got_metadata_time - start_time)
logger.debug("NIXL handshake: add agent took: %s",
setup_agent_time - got_metadata_time)
return metadata

# Handshake with the remote agent's rank 0 first to get the remote tp_size.
path = f"tcp://{host}:{port}"
logger.debug("Querying master rank metadata on path: %s", path)
with zmq_ctx(zmq.REQ, path) as sock:
metadata = handshake(sock, 0)

# Handshake only with the remote TP rank that the current local rank
# will pull from. With homogeneous TP it happens to be the same rank_i.
d_workers_per_p_worker = self._tp_size[
self.engine_id] // metadata.tp_size
p_remote_rank = self.rank // d_workers_per_p_worker
if p_remote_rank > 0:
path = f"tcp://{host}:{port + p_remote_rank}"
logger.debug("Querying metadata on path: %s at remote rank %s",
path, p_remote_rank)
with zmq_ctx(zmq.REQ, path) as sock:
metadata = handshake(sock, p_remote_rank)
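
As an editorial aside (not part of this diff), the rank mapping used here can be sketched as a tiny standalone helper; target_p_rank is a hypothetical name that mirrors the self.rank // d_workers_per_p_worker computation above:

def target_p_rank(local_rank: int, local_tp_size: int, remote_tp_size: int) -> int:
    # Group consecutive decode (D) ranks onto one prefill (P) rank.
    d_workers_per_p_worker = local_tp_size // remote_tp_size
    return local_rank // d_workers_per_p_worker

# Decode TP=4, prefill TP=2: D ranks 0,1 pull from P rank 0; D ranks 2,3 from P rank 1.
assert [target_p_rank(r, 4, 2) for r in range(4)] == [0, 0, 1, 1]
# Homogeneous TP (4 and 4): each D rank pulls from the same-numbered P rank.
assert [target_p_rank(r, 4, 4) for r in range(4)] == [0, 1, 2, 3]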

def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
"""Register the KV Cache data in nixl."""
@@ -341,14 +368,20 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
self.num_blocks = first_kv_cache.shape[0]
block_rank = 2 # [block_size, latent_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_size, kv_latent_dim = block_shape
self.kv_dim = kv_elem_size * kv_latent_dim
else:
# [2 (k and v), num_blocks, ...]
# [2 (k and v), num_blocks, block_size, kv_heads, head_dim]
self.num_blocks = first_kv_cache.shape[1]
block_rank = 3 # [block_size, kv_heads, head_dim]
block_shape = first_kv_cache.shape[-block_rank:]
self.block_size, n_kv_heads, head_dim = block_shape
# Per-token KV size in bytes across this worker's kv heads.
self.kv_dim = kv_elem_size * n_kv_heads * head_dim

# TODO(tms): self.block_len needs to be per-layer for sliding window,
# hybrid attn, etc
# block size in bytes
self.block_len = kv_elem_size * math.prod(block_shape)
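
For intuition only (not part of this diff), with an assumed non-MLA cache of dtype float16, block_size 16, 8 local kv heads and head_dim 128, the two byte sizes above work out as:

# Hypothetical numbers, purely illustrative.
kv_elem_size = 2                                   # float16
block_size, n_kv_heads, head_dim = 16, 8, 128

kv_dim = kv_elem_size * n_kv_heads * head_dim                  # 2048 bytes per token slot
block_len = kv_elem_size * block_size * n_kv_heads * head_dim  # 32768 bytes per block
assert block_len == kv_dim * block_size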

logger.debug("Registering KV_Caches. use_mla: %s, shape %s", use_mla,
@@ -382,7 +415,6 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
logger.debug("Registering descs: %s", caches_data)
self.nixl_wrapper.register_memory(descs)
logger.debug("Done registering descs")

self._registered_descs.append(descs)

# After KV Caches registered, listen for new connections.
@@ -391,7 +423,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
agent_metadata=self.nixl_wrapper.get_agent_metadata(),
kv_caches_base_addr=self.kv_caches_base_addr[self.engine_id],
num_blocks=self.num_blocks,
)
tp_size=self.world_size,
block_len=self.block_len)
ready_event = threading.Event()
self._nixl_handshake_listener_t = threading.Thread(
target=self._nixl_handshake_listener,
@@ -401,49 +434,97 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]):
self._nixl_handshake_listener_t.start()
ready_event.wait()

def add_remote_agent(self, nixl_agent_meta: NixlAgentMetadata):
def add_remote_agent(self,
nixl_agent_meta: NixlAgentMetadata,
remote_rank: int = 0):
engine_id = nixl_agent_meta.engine_id
if engine_id in self._remote_agents:
# TODO re-evaluate refreshing for scaling/recovery
if (engine_id in self._remote_agents and \
remote_rank in self._remote_agents[engine_id]):
return

self._remote_agents[engine_id] = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata)
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
if engine_id in self._tp_size:
assert self._tp_size[engine_id] == nixl_agent_meta.tp_size
self._tp_size[engine_id] = nixl_agent_meta.tp_size
self._remote_agents[engine_id][
remote_rank] = self.nixl_wrapper.add_remote_agent(
nixl_agent_meta.agent_metadata)

d_workers_per_p_worker = self._tp_size[
self.engine_id] // self._tp_size[engine_id]
assert d_workers_per_p_worker > 0, (
    "Decode TP cannot be smaller than prefill TP")

# TODO we should also check hidden_dim and kv precision, they must match
remote_block_size = nixl_agent_meta.block_len / (
self.kv_dim * d_workers_per_p_worker)
assert self.block_size == remote_block_size, (
    "Remote P worker with different block size is not supported")

# Create src descs and xfer side handles.
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
# (addr, len, device id)
blocks_data.append(
(base_addr + block_offset, self.block_len, self.rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.rank)

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.src_xfer_side_handle = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)

# Create dst descs and xfer side handles.
if d_workers_per_p_worker not in self.src_xfer_side_handle:
blocks_data = []
for base_addr in self.kv_caches_base_addr[self.engine_id]:
# NOTE With heterogeneous TP, more blocks are prepared than needed,
# since self.num_blocks >= nixl_agent_meta.num_blocks. We could create
# fewer, but then _get_block_descs_ids would need to select
# agent_meta.num_blocks instead of self.num_blocks for the local
# descriptors, which makes handling the regular flow less clean.
for block_id in range(self.num_blocks):
block_offset = block_id * self.block_len
for b in range(self.block_size):
head_offset = b * self.kv_dim
addr = base_addr + block_offset + head_offset
# (addr, len, device id)
blocks_data.append((addr, self.kv_dim, self.rank))
logger.debug("Created %s blocks for src engine %s and rank %s",
len(blocks_data), self.engine_id, self.rank)

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
# NIXL_INIT_AGENT to be used for preparations of local descs.
self.src_xfer_side_handle[
d_workers_per_p_worker] = self.nixl_wrapper.prep_xfer_dlist(
"NIXL_INIT_AGENT", descs)

# Create dst descs and xfer side handles. TP workers have same #blocks.
if engine_id in self.dst_num_blocks:
assert self.dst_num_blocks[engine_id] == nixl_agent_meta.num_blocks

self.dst_num_blocks[engine_id] = nixl_agent_meta.num_blocks

blocks_data = []
for base_addr in self.kv_caches_base_addr[engine_id]:
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * self.block_len
# (addr, len, device id)
blocks_data.append(
(base_addr + block_offset, self.block_len, self.rank))
logger.debug("Created %s blocks for dst engine %s and rank %s",
len(blocks_data), engine_id, self.rank)

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
self._remote_agents[engine_id], descs)
# With homogeneous TP, D pulls the whole kv cache from the
# corresponding rank. With heterogeneous TP, prepare the descriptors by
# splitting the P KV cache along the kv_head dim into chunks of the D
# worker's kv_head size (D>P).
# E.g. P TP1, D TP2 => P0 KV: [block0-KV_0 | block0-KV_1 ...].
p_remote_rank = self.rank // d_workers_per_p_worker
# Only register the remote's descriptors if current rank pulls from it.
if p_remote_rank == remote_rank:
self.kv_caches_base_addr[
engine_id] = nixl_agent_meta.kv_caches_base_addr
rank_offset = self.rank % d_workers_per_p_worker * self.kv_dim
# Register all remote blocks, but only the corresponding kv heads.
for base_addr in nixl_agent_meta.kv_caches_base_addr:
for block_id in range(nixl_agent_meta.num_blocks):
block_offset = block_id * nixl_agent_meta.block_len
for b in range(self.block_size):
# Remote kv_dim = local kv_dim * d_workers_per_p_worker
head_offset = b * self.kv_dim * d_workers_per_p_worker
addr = base_addr + block_offset + head_offset
# (addr, len, device id)
blocks_data.append(
(addr + rank_offset, self.kv_dim, remote_rank))
logger.debug(
"Created %s blocks for dst engine %s with remote rank %s and " \
"local rank %s",
len(blocks_data), engine_id, remote_rank, self.rank)

# Register with NIXL.
descs = self.nixl_wrapper.get_xfer_descs(blocks_data, "VRAM")
self.dst_xfer_side_handles[
engine_id] = self.nixl_wrapper.prep_xfer_dlist(
self._remote_agents[engine_id][remote_rank], descs)
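
To make the head-splitting concrete (editorial sketch, not part of this diff, reusing the small numbers above): with prefill TP=1 and decode TP=2, each remote token slot holds twice the local kv_dim; D rank 0 reads its first half and D rank 1 the second half:

# Hypothetical values: P TP=1, D TP=2, so d_workers_per_p_worker = 2.
d_workers_per_p_worker, kv_dim, block_size = 2, 2048, 4
remote_block_len = block_size * kv_dim * d_workers_per_p_worker
base_addr, num_remote_blocks = 0x9000, 1

def remote_descs(local_rank: int):
    # Offset into each token slot selecting this D rank's kv-head chunk.
    rank_offset = local_rank % d_workers_per_p_worker * kv_dim
    return [(base_addr + block_id * remote_block_len
             + b * kv_dim * d_workers_per_p_worker + rank_offset, kv_dim)
            for block_id in range(num_remote_blocks)
            for b in range(block_size)]

assert remote_descs(0)[0] == (0x9000, 2048)          # D rank 0: first half of slot 0
assert remote_descs(1)[0] == (0x9000 + 2048, 2048)   # D rank 1: second half of slot 0
assert remote_descs(0)[1] == (0x9000 + 4096, 2048)   # next slot, stride = 2 * kv_dim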

def get_finished(self) -> tuple[set[str], set[str]]:
"""
Expand Down Expand Up @@ -580,6 +661,7 @@ def _read_blocks(
request_id: str,
):
# NOTE(rob): this takes ~2s. We need to get this off the hotpath.
# TODO check remote_rank in here too?
if dst_engine_id not in self._remote_agents:
self._nixl_handshake(remote_host, remote_port)

@@ -595,9 +677,15 @@ def _read_blocks(

assert len(local_block_ids) > 0
assert len(local_block_ids) == len(remote_block_ids)
# NOTE (nicolo) With homogeneous TP, each TP worker loads KV from the
# corresponding rank. With heterogeneous TP, assuming D>P, the D TP
# workers issue xfers to parts of the P workers' remote KV caches.

# Get side handles.
local_xfer_side_handle = self.src_xfer_side_handle
d_workers_per_p_worker = self._tp_size[
self.engine_id] // self._tp_size[dst_engine_id]
local_xfer_side_handle = self.src_xfer_side_handle[
d_workers_per_p_worker]
remote_xfer_side_handle = self.dst_xfer_side_handles[dst_engine_id]

# Get descs ids.
@@ -635,7 +723,9 @@ def _get_block_descs_ids(self, engine_id: str,
descs_ids: list[int] = []
for reg_id in region_ids:
for block_id in block_ids:
descs_ids.append(reg_id * num_blocks + block_id)
for kv_block in range(self.block_size):
descs_ids.append(reg_id * num_blocks * self.block_size +
block_id * self.block_size + kv_block)
return descs_ids
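
A small worked example of the new descriptor-id layout (not part of this diff): each block now maps to block_size consecutive ids within its region, so region and block offsets are both scaled by block_size:

# Hypothetical sizes: 3 blocks per region, block_size = 4.
num_blocks, block_size = 3, 4

def descs_ids(region_ids, block_ids):
    return [reg_id * num_blocks * block_size + block_id * block_size + kv_block
            for reg_id in region_ids
            for block_id in block_ids
            for kv_block in range(block_size)]

assert descs_ids([0], [1]) == [4, 5, 6, 7]        # block 1 of region 0
assert descs_ids([1], [0]) == [12, 13, 14, 15]    # block 0 of region 1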


Expand Down