File tree Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Expand file tree Collapse file tree 1 file changed +8
-2
lines changed Original file line number Diff line number Diff line change 2323#define GEMMA_MAX_SEQLEN 4096
2424#endif // !GEMMA_MAX_SEQLEN
2525
26+ // Allow changing k parameter of `SampleTopK` as a compiler flag
27+ #ifndef GEMMA_TOPK
28+ #define GEMMA_TOPK 1
29+ #endif // !GEMMA_TOPK
30+
2631#include < stddef.h>
2732
2833namespace gcpp {
2934
3035static constexpr size_t kSeqLen = GEMMA_MAX_SEQLEN;
36+ static constexpr size_t kTopK = GEMMA_TOPK;
3137
3238struct ConfigGemma7B {
3339 static constexpr int kSeqLen = gcpp::kSeqLen ;
@@ -38,7 +44,7 @@ struct ConfigGemma7B {
3844 static constexpr int kHeads = 16 ;
3945 static constexpr int kKVHeads = 16 ; // standard MHA
4046 static constexpr int kQKVDim = 256 ; // query size == key size == value size
41- static constexpr int kTopK = 1 ;
47+ static constexpr int kTopK = gcpp:: kTopK ;
4248};
4349
4450struct ConfigGemma2B {
@@ -50,7 +56,7 @@ struct ConfigGemma2B {
5056 static constexpr int kHeads = 8 ;
5157 static constexpr int kKVHeads = 8 ; // TODO(austinvhuang): add MQA support
5258 static constexpr int kQKVDim = 256 ; // query size == key size == value size
53- static constexpr int kTopK = 1 ;
59+ static constexpr int kTopK = gcpp:: kTopK ;
5460};
5561
5662} // namespace gcpp
You can’t perform that action at this time.
0 commit comments