diff --git a/gemma/flash_attention.cc b/gemma/flash_attention.cc index a463dfee..e91e2cb5 100644 --- a/gemma/flash_attention.cc +++ b/gemma/flash_attention.cc @@ -17,7 +17,9 @@ #include #include +#include #include +#include #include #include "compression/types.h" // GEMMA_DISABLED_TARGETS diff --git a/gemma/flash_attention.h b/gemma/flash_attention.h index 236c7dc3..b8a70eaf 100644 --- a/gemma/flash_attention.h +++ b/gemma/flash_attention.h @@ -29,38 +29,39 @@ namespace gcpp { // Passed to HWY_VISIT_TARGETS; declares for one target. -#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ - namespace NAMESPACE { \ - void RMSNormAndPositionalEncoding( \ - size_t num_tokens, const QBatch& qbatch, MatPtrT& q, \ - const MatPtr& query_norm_scale, size_t layer_idx, \ - const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ - \ - void SingleFlashAttention(size_t start_pos, size_t last_pos, \ - const float* HWY_RESTRICT q, \ - const MatPtrT& k, const MatPtrT& v, \ - size_t layer_idx, \ - const AttentionActivationsPtrs& activations, \ - float* HWY_RESTRICT att_out, \ - ThreadingContext& ctx, size_t worker); \ - \ - Tile4FlashState TileFlashAttention4( \ - const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, \ - const MatPtrT& k, size_t start_pos, \ - const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \ - size_t max_last_pos, const MatPtrT& v, size_t layer_idx, \ - const LayerWeightsPtrs& layer, const AttentionActivations& activations, \ - MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, \ - ThreadingContext& ctx, const size_t worker); \ - \ - size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ - size_t total_tasks, size_t target_parallelism); \ - \ - void FlashAttention(size_t num_tokens, size_t target_parallelism, \ - size_t layer_idx, const MatPtr& query_norm_scale, \ - AttentionActivationsPtrs& activations, QBatch& qbatch, \ - ThreadingContext& ctx); \ - /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ +#define GEMMA_DECL_FLASH_ATTENTION(TARGET, NAMESPACE) \ + namespace NAMESPACE { \ + void RMSNormAndPositionalEncoding( \ + size_t num_tokens, const QBatch& qbatch, MatPtrT& q, \ + const MatPtr& query_norm_scale, size_t layer_idx, \ + const AttentionActivationsPtrs& activations, ThreadingContext& ctx); \ + \ + void SingleFlashAttention(size_t start_pos, size_t last_pos, \ + const BF16* HWY_RESTRICT q, \ + const MatPtrT& k, const MatPtrT& v, \ + size_t layer_idx, \ + const AttentionActivationsPtrs& activations, \ + float* HWY_RESTRICT att_out, \ + ThreadingContext& ctx, size_t worker); \ + \ + Tile4FlashState TileFlashAttention4( \ + const MatPtrT& q, const uint32_t* HWY_RESTRICT q_offsets, \ + const MatPtrT& k, size_t start_pos, \ + const uint32_t* HWY_RESTRICT last_pos, size_t min_last_pos, \ + size_t max_last_pos, const MatPtrT& v, size_t layer_idx, \ + const LayerWeightsPtrs& layer, const AttentionActivations& activations, \ + MatPtrT& att_out, const uint32_t* HWY_RESTRICT out_offsets, \ + ThreadingContext& ctx, const size_t worker); \ + \ + size_t GetVTileSize(size_t kNF, size_t num_head_groups, size_t num_tokens, \ + size_t total_tasks, size_t target_parallelism); \ + \ + void FlashAttention(size_t num_tokens, size_t target_parallelism, \ + size_t layer_idx, const MatPtr& query_norm_scale, \ + AttentionActivationsPtrs& activations, QBatch& qbatch, \ + ThreadingContext& ctx); \ + \ + /* NOLINTNEXTLINE(google-readability-namespace-comments) */ \ } // namespace NAMESPACE // Function declarations for each SIMD target. Allows direct call from the diff --git a/ops/ops-inl.h b/ops/ops-inl.h index 3b41ff34..0eeec31b 100644 --- a/ops/ops-inl.h +++ b/ops/ops-inl.h @@ -25,6 +25,7 @@ #include #include #include // std::enable_if_t +#include #include #include "ops/matmul.h"