lmdeploy support kernel block size #4421

Open
Tsundoku958 wants to merge 8 commits into InternLM:main from Tsundoku958:Tsundoku958/support-kernel-block-size

Conversation

@Tsundoku958
Contributor

@Tsundoku958 commented Mar 17, 2026

Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily receiving feedbacks. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers.

Motivation

Currently, using a large block size (e.g., >= 128) for the KV cache triggers a Triton kernel shared-memory overflow. However, users may require the flexibility to set an arbitrary block size to control the granularity of the Prefix Cache and the Block Manager.

This PR decouples the block size for cache management from the block size for kernel execution. It introduces two configurable parameters:

  • manager_block_size: the block size as seen by the Block Manager.
  • kernel_block_size: the block size as seen by the Triton kernels.

Modification

  1. The KV cache in the Cache Engine is now allocated and structured based on the kernel_block_size.
  2. The Block Manager performs all its operations (allocation, swapping, scheduling) using the manager_block_size as its fundamental unit.
  3. When creating model inputs, the manager_block_ids from the Block Manager are converted into kernel_block_ids for the CUDA kernels.
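The mapping in step 3 can be sketched as follows. This is a minimal illustration, not the PR's actual code: it assumes manager_block_size is an integer multiple of kernel_block_size, so that manager block i covers k = manager_block_size // kernel_block_size consecutive kernel blocks.

```python
import numpy as np

def manager_to_kernel_block_ids(manager_block_ids: np.ndarray,
                                manager_block_size: int,
                                kernel_block_size: int) -> np.ndarray:
    """Expand manager-level block ids into kernel-level block ids.

    Illustrative sketch only: assumes manager_block_size is an integer
    multiple of kernel_block_size, so manager block i maps to kernel
    blocks [i*k, (i+1)*k) with k = manager_block_size // kernel_block_size.
    """
    assert manager_block_size % kernel_block_size == 0
    k = manager_block_size // kernel_block_size
    # (n, 1) * k + (k,) broadcasts to (n, k); flatten row-major.
    return (manager_block_ids[:, None] * k + np.arange(k)).reshape(-1)

# A sequence holding manager blocks [2, 5], with 128-token manager blocks
# and 32-token kernel blocks, expands to 4 kernel blocks per manager block.
ids = manager_to_kernel_block_ids(np.array([2, 5]), 128, 32)
print(ids)  # [ 8  9 10 11 20 21 22 23]
```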

Before this PR:
lmdeploy serve api_server ../Qwen3-Next-80B-A3B-Thinking/ --backend pytorch --max-batch-size 16 --tp 4 --cache-block-seq-len 128
Traceback

  File "/home/lmdeploy/lmdeploy/pytorch/backends/cuda/attention/default.py", line 206, in _forward_decoding
    attn_output = self.paged_attention_fwd(
  File "/home/lmdeploy/lmdeploy/pytorch/kernels/cuda/pagedattention.py", line 721, in flash_attn_with_kvcache
    _fwd_grouped_split_kernel[grid](q,
  File "/opt/py3/lib/python3.10/site-packages/triton/runtime/jit.py", line 390, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/opt/py3/lib/python3.10/site-packages/triton/runtime/jit.py", line 617, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/opt/py3/lib/python3.10/site-packages/triton/compiler/compiler.py", line 498, in __getattribute__
    self._init_handles()
  File "/opt/py3/lib/python3.10/site-packages/triton/compiler/compiler.py", line 483, in _init_handles
    raise OutOfResources(self.metadata.shared, max_shared, "shared memory")
triton.runtime.errors.OutOfResources: out of resource: shared memory, Required: 274432, Hardware limit: 232448. Reducing block sizes or `num_stages` may help.

After this PR:
lmdeploy serve api_server ../Qwen3-Next-80B-A3B-Thinking/ --backend pytorch --max-batch-size 16 --tp 4 --cache-block-seq-len 128 --kernel-block-size 32

@Tsundoku958 marked this pull request as ready for review March 18, 2026 11:38
@lvhan028
Collaborator

@grimoire Please evaluate the necessity of introducing the parameter kernel_block_size

@Tsundoku958
Contributor Author

Tsundoku958 commented Mar 19, 2026

@grimoire Please evaluate the necessity of introducing the parameter kernel_block_size

Another motivation for this pull request is that I hope to enable shared GPU memory for the linear attention cache and the full attention's key-value cache in qwen3-next. If the recurrent_cache can be stored within a single block, then the linear attention cache can be managed with the help of the block manager. However, I ran into an issue: the current block size is limited to a maximum of 64, so the recurrent_cache requires 8 consecutive blocks for storage, which makes block management very complex. After this PR, I can set the block size to a value large enough to store a single recurrent_cache, while the kernel block size stays at 64.
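As a back-of-the-envelope sketch of the sizing argument above (the 512-slot figure is only inferred from "8 consecutive blocks" at the 64 cap, and is illustrative rather than the model's real number):

```python
import math

# Illustrative only: per-state size inferred from "8 consecutive blocks"
# at the 64-token block cap; the actual size depends on the model.
recurrent_state_size = 512

# Before this PR: kernel and manager block size coincide (capped at 64),
# so one recurrent state spans several consecutive blocks.
blocks_before = math.ceil(recurrent_state_size / 64)

# After this PR: manager_block_size can be raised to 512 while kernels
# keep running with kernel_block_size=64, so one manager block suffices.
blocks_after = math.ceil(recurrent_state_size / 512)

print(blocks_before, blocks_after)  # 8 1
```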

head_size = model_config.head_dim
shape = cls._get_key_block_shape_impl(
    model_config,
    block_size=cache_config.block_size,
Collaborator


init

self.block_size = cache_config.block_size

allocate_custom_cache

            custom_shape = self.get_custom_cache_shape_impl(
                num_layers=num_layers,
                num_blocks=self.num_gpu_blocks,
                block_size=self.block_size,
                shape=shape,
            )

Contributor Author


I'm not quite sure what this means. Could you please explain it more clearly?

Collaborator


block_offsets = []
for seq in seqs:
    block_offset = self.block_manager.get_block_table(seq)
    block_offset = block_offset.repeat(self.kernel_blocks_per_kv) * self.kernel_blocks_per_kv + np.tile(
Collaborator


repeat and tile each allocate a new array for their output. Try doing this with broadcasting instead.
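For illustration, a minimal sketch of the broadcast form suggested here (the kernel_blocks_per_kv value is assumed), checked against the repeat/tile expression it replaces:

```python
import numpy as np

k = 4  # assumed kernel_blocks_per_kv, for illustration only
block_offset = np.array([3, 7, 1])  # manager-level block table of one seq

# Current form: repeat and tile each materialize a full-size temporary.
v_repeat = block_offset.repeat(k) * k + np.tile(np.arange(k), len(block_offset))

# Broadcast form: one (n, k) result from a single expression, with no
# repeat/tile intermediates, flattened row-major at the end.
v_bcast = (block_offset[:, None] * k + np.arange(k)).reshape(-1)

assert (v_repeat == v_bcast).all()
print(v_bcast)  # [12 13 14 15 28 29 30 31  4  5  6  7]
```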

Collaborator


And we can do this in inputs_maker after tensorlizing. Computing the new block table in one big ndarray/tensor should be better than looping.

Contributor Author


It's done

@grimoire
Collaborator

Another motivation for this pull request is that I hope to enable shared GPU memory for the linear attention cache and the

Cool, this is also on our development roadmap, and I'm glad to see community developers getting involved.

However, since this feature touches many core modules, we're cautious about adding it to the engine right after integrating Qwen3-next and Qwen3.5 — each large feature like this could become future technical debt. If you submit changes in this area, we can't guarantee a timely review or merge.

@Tsundoku958 force-pushed the Tsundoku958/support-kernel-block-size branch from 9ea365d to a8657d3 on March 23, 2026 12:24
off_len = len(off)
out[idx, :off_len] = off
off_len = len(off) * kernel_blocks_per_kv
out[idx, :off_len] = (off[:, None] * kernel_blocks_per_kv + kernel_block_arange).reshape(-1)
Collaborator

@grimoire commented Mar 24, 2026


_tensorlize_block_offsets consumes significant CPU resources, so avoid block expansion unless necessary. Additionally, we should be able to do this directly on out after the loop; repeatedly entering Python frames inside the loop introduces substantial overhead from the extra function calls.

Also, I'd prefer adding a separate function after _tensorlize_block_offsets to handle this, rather than modifying _tensorlize_block_offsets itself.
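A minimal sketch of such a separate post-processing step, run once on the padded out array after the loop (the helper name and shapes are assumptions, and masking of padding entries is left out):

```python
import numpy as np

def expand_to_kernel_blocks(out: np.ndarray, kernel_blocks_per_kv: int) -> np.ndarray:
    """Hypothetical helper applied after _tensorlize_block_offsets.

    Expands a padded (batch, max_blocks) manager-level block table to
    kernel-level ids in one vectorized step, instead of per-row inside
    the Python loop. Padding entries get expanded too; masking them is
    omitted from this sketch.
    """
    k = kernel_blocks_per_kv
    batch, max_blocks = out.shape
    # (batch, max_blocks, 1) * k + (k,) broadcasts to (batch, max_blocks, k)
    expanded = out[:, :, None] * k + np.arange(k)
    return expanded.reshape(batch, max_blocks * k)

table = np.array([[2, 0], [5, 1]])
print(expand_to_kernel_blocks(table, 2))  # [[ 4  5  0  1] [10 11  2  3]]
```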

Collaborator


There is another "get block table" behaviour in create_model_inputs_long_context that does not go through _tensorlize_block_offsets.

@Tsundoku958
Contributor Author

Could you review this again? @grimoire

Collaborator

@grimoire left a comment


LGTM

@Tsundoku958
Contributor Author

Hi @lvhan028 , this PR was approved a while ago but hasn't been merged yet. Is there anything else needed on my end, or is it ready to go?
