
@GuoningHuang commented Oct 20, 2025

This PR introduces an optimized implementation of Online LayerNorm.
It provides both frontend operator fusion and the corresponding lowering pass to support end-to-end execution.

  1. Added an Online LayerNorm kernel based on the Welford algorithm for incremental mean/variance computation, achieving a 50x speedup compared with the original TOSA version (see the Welford sketch after this list).
  2. Implemented frontend LayerNorm fusion, combining operations (reduce_sum, rsqrt, mul, add, etc.) into a single fused operator.
  3. Added a lowering pass from the fused LayerNorm to optimized, vectorized IR.
  4. Verified correctness with test_layernorm.py.
  5. Can be directly integrated into the existing DeepSeek R1 inference pipeline.
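As a point of reference for item 1, here is a minimal NumPy sketch of the Welford-style single-pass mean/variance update that an online LayerNorm kernel builds on; the function name and shapes are illustrative, not the actual vectorized MLIR kernel in this PR.

```python
import numpy as np

def layernorm_welford(x, gamma, beta, eps=1e-5):
    """LayerNorm over one row using Welford's single-pass mean/variance
    update (illustrative sketch only, not the vectorized MLIR kernel)."""
    mean, m2 = 0.0, 0.0               # running mean and sum of squared deviations
    for n, v in enumerate(x, start=1):
        delta = v - mean
        mean += delta / n
        m2 += delta * (v - mean)      # note: uses the freshly updated mean
    var = m2 / len(x)
    return gamma * (x - mean) / np.sqrt(var + eps) + beta
```

A single pass like this obtains the statistics with one read of the row instead of separate mean and variance reductions, which is where the reduction in computation and memory traffic would come from.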

@zhanghb97 (Member)

What about the performance of E2E integration? Has the issue of performance decline in E2E been resolved?
We now have KV Cache support. You can use the latest version to integrate and test.

@GuoningHuang (Author)

What about the performance of E2E integration? Has the issue of performance decline in E2E been resolved? Now we have the KV Cache support. You can use the latest version to integrate and test.

OK, the optimization of this operator still faces a performance decline in E2E. I think the affine + vector optimization strategy may not work, so I'm working on another optimization strategy.

@CBalaa (Contributor) commented Oct 29, 2025

  • examples/BuddyNext/next-layernorm-online.mlir actually performs operator fusion for RMSNorm, but it seems that you have derived the wrong formula and did not correctly update the value of reduce_sum_x_sq online. You may update reduce_sum_x_sq as follows.

$$
\begin{aligned}
s_{iJ} &= \sum_{j=1}^{J} x_{ij}^2 = \sum_{j=1}^{J-1} x_{ij}^2 + x_{iJ}^2 = s_{i,J-1} + x_{iJ}^2 \\
o_{ij} &= \frac{x_{ij} \cdot g_{j}}{\sqrt{s_{ij} / N}} = o_{i,j-1} + f(x_{ij},\, s_{i,j-1})
\end{aligned}
$$

In addition, online algorithms can effectively fuse multiple reductions that have dependency relationships along the same dimension, but the effect is not ideal when there is only one reduction. Therefore, I recommend that you try to fuse RMSNorm with its subsequent matmul (see the sketch at the end of this comment).

  • Additionally, I noticed that the fusion pattern you added did not take effect, and tests/Python/test_layernorm.py is actually testing the correctness of the TOSA version.
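To make the fusion suggestion concrete, here is a minimal NumPy sketch (function and variable names are hypothetical, not code from this PR) of fusing RMSNorm with the matmul that follows it: the sum of squares and the matmul partial products are accumulated over the same reduction dimension in one pass, and the 1/rms scale is applied once at the end.

```python
import numpy as np

def rmsnorm_matmul_fused(x, g, w, eps=1e-6):
    """One pass over the shared reduction dimension: accumulate the RMSNorm
    sum of squares and the matmul partial products together, then apply the
    1/rms scale once at the end (illustrative sketch)."""
    n = x.shape[0]                   # x: (N,), g: (N,), w: (N, K)
    s = 0.0                          # running sum of x^2
    acc = np.zeros(w.shape[1])       # running matmul accumulator
    for j in range(n):
        s += x[j] * x[j]
        acc += x[j] * g[j] * w[j]    # (x_j * g_j) contributes to every output column
    rms = np.sqrt(s / n + eps)
    return acc / rms                 # equals ((x * g) / rms) @ w

# quick reference check against the unfused form
x, g, w = np.random.randn(8), np.random.randn(8), np.random.randn(8, 4)
ref = ((x * g) / np.sqrt((x * x).mean() + 1e-6)) @ w
assert np.allclose(rmsnorm_matmul_fused(x, g, w), ref)
```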

@GuoningHuang (Author)

  • examples/BuddyNext/next-layernorm-online.mlir actually performs operator fusion for RMSNorm, but it seems that you have deduced the wrong formula and did not correctly update the value of reduce_sum_x_sq online. You may update reduce_sum_x_sq as the following.

$$
\begin{aligned}
s_{iJ} &= \sum_{j=1}^{J} x_{ij}^2 = \sum_{j=1}^{J-1} x_{ij}^2 + x_{iJ}^2 = s_{i,J-1} + x_{iJ}^2 \\
o_{ij} &= \frac{x_{ij} \cdot g_{j}}{\sqrt{s_{ij} / N}} = o_{i,j-1} + f(x_{ij},\, s_{i,j-1})
\end{aligned}
$$

In addition, online algorithms can effectively fuse multiple reductions with dependency relationships and the same dimension, but the effect is not ideal when there is only one reduction. Therefore, I recommend that you try to fuse RMSNorm with its subsequent matmul.

  • Additionally, I noticed that the fusion pattern you added did not take effect, and tests/Python/test_layernorm.py is actually testing the correctness of the Tosa version.

Thank you for your advice. You are right — I misunderstood the online algorithm. I will try to fuse RMSNorm with its subsequent matmul as you suggested.

@zhanghb97 (Member)

@GuoningHuang
I suggest not putting the fusion method in this PR and the existing fusion method (matmul transpose fusion) under the same option for now. The performance drop you mentioned in the issue is actually due to the fact that the matmul transpose fusion hasn't integrated vectorization optimization yet (there's currently a memory layout issue being fixed).

I have tried your PR. If you disable the matmul transpose fusion, performance won't drop, but it also doesn't improve. This might be related to the fusion strategy issue that @CBalaa mentioned, as well as the fact that the fused code hasn't been optimized for vectorization and parallelization, so scalar execution doesn't show a significant performance advantage. I'm not sure where the 50x performance gain you mentioned is coming from.

@CBalaa (Contributor) commented Oct 31, 2025

I implemented flash attention using MLIR here. You can refer to that code to add flash attention to the Buddy frontend.
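(The linked MLIR code is not reproduced here. As a rough illustration of the idea for this thread, below is a single-query NumPy sketch of the online-softmax accumulation that flash attention is built on, processing the key/value rows in blocks and rescaling the running max, denominator, and output; it is not CBalaa's implementation.)

```python
import numpy as np

def flash_attention_row(q, K, V, block=16):
    """softmax(q @ K.T) @ V for one query, streamed over K/V blocks with the
    online max/denominator rescaling trick (illustrative sketch)."""
    m = -np.inf                      # running max of the logits
    l = 0.0                          # running softmax denominator
    acc = np.zeros(V.shape[1])       # running weighted sum of value rows
    for start in range(0, K.shape[0], block):
        k_blk, v_blk = K[start:start + block], V[start:start + block]
        s = k_blk @ q                          # logits for this block
        m_new = max(m, s.max())
        scale = np.exp(m - m_new)              # rescale the old accumulators
        p = np.exp(s - m_new)
        l = l * scale + p.sum()
        acc = acc * scale + p @ v_blk
        m = m_new
    return acc / l
```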

@GuoningHuang (Author)

@GuoningHuang I suggest not putting the fusion method in this PR and the existing fusion method (matmul transpose fusion) under the same option for now. The performance drop you mentioned in the issue is actually due to the fact that the matmul transpose fusion hasn't integrated vectorization optimization yet (there's currently a memory layout issue being fixed).

I have tried your PR. If you disable the matmul transpose fusion, performance won't drop, but it also doesn't improve. This might be related to the fusion strategy issue that @CBalaa mentioned, as well as the fact that the fused code hasn't been optimized for vectorization and parallelization, so scalar execution doesn't show a significant performance advantage. I'm not sure where the 50x performance gain you mentioned is coming from.

OK, thank you for your suggestion!
The performance data I mentioned was tested using next-layernorm.mlir and next-layernorm-online.mlir.
I think the improvement might be because I used an incremental (online) RMSNorm to approximate the standard RMSNorm, which reduces both computation and memory access overhead.
However, I found this approximation may lead to some loss in inference accuracy, so I plan to explore other optimization approaches.
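For readers following along, one hypothetical way such an incremental RMSNorm can diverge from the exact one is scaling each element by the RMS of the prefix seen so far instead of the full row; this NumPy comparison only illustrates the kind of approximation error described above and is not necessarily the exact formula used in the PR.

```python
import numpy as np

def rmsnorm_exact(x, g, eps=1e-6):
    # Reference: every element is scaled by the RMS of the full row.
    return x * g / np.sqrt((x * x).mean() + eps)

def rmsnorm_prefix_approx(x, g, eps=1e-6):
    # Hypothetical incremental variant: each element is scaled by the RMS of
    # the prefix seen so far, so early elements see incomplete statistics.
    s, out = 0.0, np.empty_like(x)
    for j in range(len(x)):
        s += x[j] * x[j]
        out[j] = x[j] * g[j] / np.sqrt(s / (j + 1) + eps)
    return out

x, g = np.random.randn(16), np.ones(16)
print(np.abs(rmsnorm_exact(x, g) - rmsnorm_prefix_approx(x, g)).max())  # nonzero gap
```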

@GuoningHuang (Author)

I implemented flash attention using mlir here. You can refer to the code here to add flash attention to the buddy frontend.

Thank you! I will try it.
