 package org.beehive.gpullama3.model.loader;
 
+import org.beehive.gpullama3.core.model.GGMLType;
 import org.beehive.gpullama3.core.model.GGUF;
 import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
-import org.beehive.gpullama3.core.model.tensor.FloatTensor;
 import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
 import org.beehive.gpullama3.core.types.Pair;
 import org.beehive.gpullama3.inference.operation.RoPE;
 import org.beehive.gpullama3.inference.weights.Weights;
 import org.beehive.gpullama3.inference.weights.standard.LlamaStandardWeights;
-import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.LlamaTornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights;
 import org.beehive.gpullama3.model.format.ChatFormat;
 import org.beehive.gpullama3.model.llama.Llama;
 import org.beehive.gpullama3.model.llama.LlamaConfiguration;
 import org.beehive.gpullama3.tokenizer.impl.LlamaTokenizer;
 import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
 import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 
 import java.nio.channels.FileChannel;
 import java.util.Map;
 
+import static org.beehive.gpullama3.model.loader.ModelLoader.*;
+
 public class LlamaModelLoader extends AbstractModelLoader<Llama, LlamaConfiguration> {
 
     public LlamaModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
@@ -41,10 +44,18 @@ protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary voc
     protected LlamaConfiguration createConfiguration(Map<String, Object> metadata) {
         int vocabSize = metadata.containsKey("llama.vocab_size") ? (int) metadata.get("llama.vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length");
 
-        return new LlamaConfiguration((int) metadata.get("llama.embedding_length"), (int) metadata.get("llama.feed_forward_length"), (int) metadata.get("llama.block_count"),
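+        // llama.attention.head_count_kv is optional in GGUF; without it the KV-head count falls back to head_count (no grouped-query attention).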
+        return new LlamaConfiguration(
+                (int) metadata.get("llama.embedding_length"),
+                (int) metadata.get("llama.feed_forward_length"),
+                (int) metadata.get("llama.block_count"),
                 (int) metadata.get("llama.attention.head_count"),
-                metadata.containsKey("llama.attention.head_count_kv") ? (int) metadata.get("llama.attention.head_count_kv") : (int) metadata.get("llama.attention.head_count"), vocabSize,
-                (int) metadata.get("llama.context_length"), (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
+                metadata.containsKey("llama.attention.head_count_kv")
+                        ? (int) metadata.get("llama.attention.head_count_kv")
+                        : (int) metadata.get("llama.attention.head_count"),
+                vocabSize,
+                (int) metadata.get("llama.context_length"),
+                (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
                 (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)).withContextLength(contextLength);
     }
 
@@ -63,41 +73,81 @@ protected Llama createModel(LlamaConfiguration config, Tokenizer tokenizer, Weig
     protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
             GGMLTensorEntry outputWeight) {
 
-        return new LlamaStandardWeights(ModelLoader.loadQuantized(tokenEmbeddings),
-                ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
-                ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
-                ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
-                ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
-                ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
-                ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
-                ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
-                ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
-                ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
-                ModelLoader.loadQuantized(tensorEntries.get("output_norm.weight")),
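+        // Standard (CPU) path: weights keep their original on-disk GGUF quantization.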
+        return new LlamaStandardWeights(
+                loadQuantized(tokenEmbeddings),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+                loadQuantized(tensorEntries.get("output_norm.weight")),
                 new ArrayFloatTensor(ropeFreqs.first()),
                 new ArrayFloatTensor(ropeFreqs.second()),
-                ModelLoader.loadQuantized(outputWeight),
+                loadQuantized(outputWeight),
                 outputWeight.ggmlType());
     }
 
     @Override
     protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
             GGMLTensorEntry outputWeight) {
+        GGMLType ggmlType = outputWeight.ggmlType();
+        if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+            System.out.println("Loading model weights in TornadoVM format (" + ggmlType + ")");
+        }
+
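+        // Dispatch on the output tensor's GGML type; F16 and Q8_0 weight layouts are supported on TornadoVM.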
+        return switch (ggmlType) {
+            case F16 -> createTornadoVMWeightsF16(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+            case Q8_0 -> createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+            default -> throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
+        };
+    }
 
-        return new LlamaTornadoWeights(ModelLoader.loadTensorAsFloatArray(tokenEmbeddings),
-                ModelLoader.loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
-                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
-                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
-                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
-                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
-                ModelLoader.loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
-                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
-                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
-                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
-                ModelLoader.floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
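+    // F16 path: per-layer matmul weights are loaded as half-float arrays; norm vectors stay in FP32.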
+    private Weights createTornadoVMWeightsF16(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
+        return new LlamaTornadoWeights(
+                loadTensorAsFloatArray(tokenEmbeddings),
+                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+                floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
                 FloatArray.fromArray(ropeFreqs.first()),
                 FloatArray.fromArray(ropeFreqs.second()),
-                ModelLoader.loadTensorAsHalfFloatArray(outputWeight),
-                outputWeight.ggmlType());
+                loadTensorAsHalfFloatArray(outputWeight),
+                outputWeight.ggmlType()
+        );
+    }
+
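+    // Q8_0 path: matmul weights stay block-quantized as Q8_0 on the device; norm vectors load as FP32.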
+    private Q8_0Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
+        return new Q8_0Weights(
+                loadTensorAsFloatArray(tokenEmbeddings),
+                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+                floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
+                FloatArray.fromArray(ropeFreqs.first()),
+                FloatArray.fromArray(ropeFreqs.second()),
+                loadQ8_0QuantizedTensor(outputWeight),
+                outputWeight.ggmlType()
+        );
     }
 }