Skip to content

Conversation

@xiehaohx
Copy link

ring_attn: fp8_comm & kv_fusion

@gemini-code-assist
Copy link
Contributor

Summary of Changes

Hello @xiehaohx, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant optimizations to the RingAttnWeight class by enabling FP8 communication and a Key-Value (KV) fusion strategy. These enhancements are designed to boost the efficiency of distributed attention mechanisms, primarily by minimizing data transfer sizes during communication and streamlining the attention computation when keys and values are combined. The modifications provide greater flexibility in configuring the attention mechanism for improved performance and resource utilization.

Highlights

  • FP8 Communication Support: The RingAttnWeight class now supports 8-bit floating point (FP8) communication for key and value tensors, which can significantly reduce communication overhead and memory footprint in distributed attention mechanisms. The previous assertion preventing FP8 communication has been removed.
  • Key-Value (KV) Fusion: A new use_kv_fusion flag has been introduced, allowing for the fusion of key and value tensors into a single kv tensor. This fused tensor is then used for both communication and attention computation, potentially improving efficiency.
  • Refactored apply Method: The main apply method within RingAttnWeight has been updated to conditionally handle the new FP8 communication and KV fusion logic, providing greater flexibility and optimization options for attention computation.
  • New ring_attn_sub_kv_fusion Method: A dedicated helper method, ring_attn_sub_kv_fusion, has been added to manage the attention computation specifically when KV fusion is enabled. This method utilizes flash_attn by appropriately splitting the fused kv tensor.

🧠 New Feature in Public Preview: You can now enable Memory to help Gemini Code Assist learn from your team's feedback. This makes future code reviews more consistent and personalized to your project's style. Click here to enable Memory in your admin console.

Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point by creating a comment using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands on the current page.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in pull request comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This pull request introduces support for FP8 communication and key-value tensor fusion within the Ring Attention mechanism. These are valuable optimizations that can enhance performance by reducing communication overhead. The implementation adds use_fp8_comm and use_kv_fusion flags to enable these features. My review has identified a critical bug related to the KV fusion logic in the final attention step, which would result in incorrect outputs. Additionally, there's a significant amount of code duplication that complicates maintenance. I've provided detailed comments and suggestions to address these issues.


block_out, block_lse = self.ring_attn_sub(q, k, v)
if use_kv_fusion:
next_kv = torch.cat((kv, txt_kv), dim=1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

There is a critical bug in the logic for the final step of the ring attention when use_kv_fusion is enabled. The result of torch.cat((kv, txt_kv), dim=1) is assigned to next_kv, but the subsequent attention calculation at line 164 still uses the original kv. This means the text tokens (txt_kv) are not included in the attention computation for the final block, leading to incorrect results. The else branch for the non-fused case correctly updates k and v in-place. The same should be done for kv.

                    kv = torch.cat((kv, txt_kv), dim=1)

Comment on lines 127 to 182
for step in range(world_size):
if step + 1 != world_size:
next_k = RING_COMM.send_recv(k)
next_v = RING_COMM.send_recv(v)
RING_COMM.commit()
if use_fp8_comm:
if use_kv_fusion:
kv_fp8, kv_scale = quant_fp8_vllm(kv.reshape(-1, hidden_dims))
kv_fp8 = kv_fp8.reshape(original_shape)
kv_scale = kv_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1)
next_kv_fp8 = RING_COMM.send_recv(kv_fp8)
next_kv_scale = RING_COMM.send_recv(kv_scale)
else:
k_fp8, k_scale = quant_fp8_vllm(k.reshape(-1, hidden_dims))
v_fp8, v_scale = quant_fp8_vllm(v.reshape(-1, hidden_dims))
k_fp8 = k_fp8.reshape(original_shape)
v_fp8 = v_fp8.reshape(original_shape)
k_scale = k_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1)
v_scale = v_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1)
next_k_fp8 = RING_COMM.send_recv(k_fp8)
next_k_scale = RING_COMM.send_recv(k_scale)
next_v_fp8 = RING_COMM.send_recv(v_fp8)
next_v_scale = RING_COMM.send_recv(v_scale)
RING_COMM.commit()
else:
if use_kv_fusion:
next_kv = RING_COMM.send_recv(kv)
else:
next_k = RING_COMM.send_recv(k)
next_v = RING_COMM.send_recv(v)
RING_COMM.commit()

if step + 1 == world_size:
k = torch.cat((k, txt_k), dim=1)
v = torch.cat((v, txt_v), dim=1)

block_out, block_lse = self.ring_attn_sub(q, k, v)
if use_kv_fusion:
next_kv = torch.cat((kv, txt_kv), dim=1)
else:
k = torch.cat((k, txt_k), dim=1)
v = torch.cat((v, txt_v), dim=1)

if use_kv_fusion:
block_out, block_lse = self.ring_attn_sub_kv_fusion(q, kv)
else:
block_out, block_lse = self.ring_attn_sub(q, k, v)
out, lse = self.update_out_and_lse(out, lse, block_out, block_lse)

if step + 1 != world_size:
RING_COMM.wait()
k = next_k
v = next_v
if use_fp8_comm:
if use_kv_fusion:
kv = dequant_fp8_vllm(next_kv_fp8, next_kv_scale, original_dtype)
else:
k = dequant_fp8_vllm(next_k_fp8, next_k_scale, original_dtype)
v = dequant_fp8_vllm(next_v_fp8, next_v_scale, original_dtype)
else:
if use_kv_fusion:
kv = next_kv
else:
k = next_k
v = next_v
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The logic within the for loop contains significant code duplication, particularly in the handling of use_fp8_comm with and without use_kv_fusion. The quantization, communication, and dequantization steps are nearly identical for k, v, and the fused kv tensor. This repetition makes the code harder to read, maintain, and debug.

Consider refactoring this logic into smaller helper functions. For instance, you could have a helper for quantization and communication:

def _quant_and_send(tensor, original_shape, hidden_dims, comm):
    tensor_fp8, tensor_scale = quant_fp8_vllm(tensor.reshape(-1, hidden_dims))
    tensor_fp8 = tensor_fp8.reshape(original_shape)
    tensor_scale = tensor_scale.reshape(original_shape[0], original_shape[1], original_shape[2], 1)
    next_tensor_fp8 = comm.send_recv(tensor_fp8)
    next_tensor_scale = comm.send_recv(tensor_scale)
    return next_tensor_fp8, next_tensor_scale

A similar helper could be created for dequantization. This would abstract away the repeated logic and simplify the main loop's structure.

cur_rank = dist.get_rank(seq_p_group)
world_size = dist.get_world_size(seq_p_group)

img_qkv_len = slice_qkv_len
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The parameter img_qkv_len was renamed to slice_qkv_len in the method signature, which is a good change for generality. However, it's immediately reassigned to a local variable img_qkv_len. For consistency and to improve code clarity, it would be better to use slice_qkv_len throughout the function's body, removing this redundant assignment.

        # Consider removing this line and replacing all instances of `img_qkv_len`
        # in this function with `slice_qkv_len` for better clarity and consistency.
        img_qkv_len = slice_qkv_len

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To adapt to the project requirements, the input parameter must be named slice_qkv_len. In subsequent code, img_qkv_len used to differentiate between image and text components.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant