Skip to content

Commit 851bb07

Browse files
Refine model loader refactoring and converge with Q8 support
1 parent a1359cb commit 851bb07

File tree

7 files changed

+391
-290
lines changed

7 files changed

+391
-290
lines changed

src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ public abstract class AbstractModelLoader<M extends Model, C extends Configurati
2929
protected final boolean loadWeights;
3030
protected final boolean useTornadovm;
3131

32+
protected Vocabulary vocabulary;
33+
3234
protected AbstractModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
3335
this.fileChannel = fileChannel;
3436
this.gguf = gguf;
@@ -47,7 +49,7 @@ public final M loadModel() {
4749
Map<String, Object> metadata = gguf.getMetadata();
4850

4951
// Step 1: Load vocabulary
50-
Vocabulary vocabulary = loadVocabulary(metadata);
52+
this.vocabulary = loadVocabulary(metadata);
5153

5254
// Step 2: Create tokenizer
5355
Tokenizer tokenizer = createTokenizer(metadata, vocabulary);
Lines changed: 77 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,29 @@
11
package org.beehive.gpullama3.model.loader;
22

3+
import org.beehive.gpullama3.core.model.GGMLType;
34
import org.beehive.gpullama3.core.model.GGUF;
45
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
5-
import org.beehive.gpullama3.core.model.tensor.FloatTensor;
66
import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
77
import org.beehive.gpullama3.core.types.Pair;
88
import org.beehive.gpullama3.inference.operation.RoPE;
99
import org.beehive.gpullama3.inference.weights.Weights;
1010
import org.beehive.gpullama3.inference.weights.standard.LlamaStandardWeights;
11-
import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
11+
import org.beehive.gpullama3.inference.weights.tornado.fp16.LlamaTornadoWeights;
12+
import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights;
1213
import org.beehive.gpullama3.model.format.ChatFormat;
1314
import org.beehive.gpullama3.model.llama.Llama;
1415
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
1516
import org.beehive.gpullama3.tokenizer.impl.LlamaTokenizer;
1617
import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
1718
import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
19+
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
1820
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
19-
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
2021

2122
import java.nio.channels.FileChannel;
2223
import java.util.Map;
2324

25+
import static org.beehive.gpullama3.model.loader.ModelLoader.*;
26+
2427
public class LlamaModelLoader extends AbstractModelLoader<Llama, LlamaConfiguration> {
2528

2629
public LlamaModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
@@ -41,10 +44,17 @@ protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary voc
4144
protected LlamaConfiguration createConfiguration(Map<String, Object> metadata) {
4245
int vocabSize = metadata.containsKey("llama.vocab_size") ? (int) metadata.get("llama.vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length");
4346

44-
return new LlamaConfiguration((int) metadata.get("llama.embedding_length"), (int) metadata.get("llama.feed_forward_length"), (int) metadata.get("llama.block_count"),
47+
return new LlamaConfiguration(
48+
(int) metadata.get("llama.embedding_length"),
49+
(int) metadata.get("llama.feed_forward_length"),
50+
(int) metadata.get("llama.block_count"),
4551
(int) metadata.get("llama.attention.head_count"),
46-
metadata.containsKey("llama.attention.head_count_kv") ? (int) metadata.get("llama.attention.head_count_kv") : (int) metadata.get("llama.attention.head_count"), vocabSize,
47-
(int) metadata.get("llama.context_length"), (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
52+
metadata.containsKey("llama.attention.head_count_kv") ?
53+
(int) metadata.get("llama.attention.head_count_kv")
54+
: (int) metadata.get("llama.attention.head_count"),
55+
vocabSize,
56+
(int) metadata.get("llama.context_length"),
57+
(float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
4858
(float) metadata.getOrDefault("llama.rope.freq_base", 10000f)).withContextLength(contextLength);
4959
}
5060

@@ -63,41 +73,77 @@ protected Llama createModel(LlamaConfiguration config, Tokenizer tokenizer, Weig
6373
protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
6474
GGMLTensorEntry outputWeight) {
6575

66-
return new LlamaStandardWeights(ModelLoader.loadQuantized(tokenEmbeddings),
67-
ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
68-
ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
69-
ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
70-
ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
71-
ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
72-
ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
73-
ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
74-
ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
75-
ModelLoader.loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
76-
ModelLoader.loadQuantized(tensorEntries.get("output_norm.weight")),
76+
return new LlamaStandardWeights(
77+
loadQuantized(tokenEmbeddings),
78+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
79+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
80+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
81+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
82+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
83+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
84+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
85+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
86+
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
87+
loadQuantized(tensorEntries.get("output_norm.weight")),
7788
new ArrayFloatTensor(ropeFreqs.first()),
7889
new ArrayFloatTensor(ropeFreqs.second()),
79-
ModelLoader.loadQuantized(outputWeight),
90+
loadQuantized(outputWeight),
8091
outputWeight.ggmlType());
8192
}
8293

8394
@Override
8495
protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
8596
GGMLTensorEntry outputWeight) {
97+
if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
98+
System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
99+
}
100+
101+
GGMLType ggmlType = outputWeight.ggmlType();
102+
return switch(ggmlType) {
103+
case F16 -> createTornadoVMWeightsF16(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
104+
case Q8_0 -> createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
105+
default -> throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
106+
};
107+
}
86108

87-
return new LlamaTornadoWeights(ModelLoader.loadTensorAsFloatArray(tokenEmbeddings),
88-
ModelLoader.loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
89-
ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
90-
ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
91-
ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
92-
ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
93-
ModelLoader.loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
94-
ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
95-
ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
96-
ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
97-
ModelLoader.floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
109+
private Weights createTornadoVMWeightsF16(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
110+
GGMLTensorEntry outputWeight) {
111+
return new LlamaTornadoWeights(
112+
loadTensorAsFloatArray(tokenEmbeddings),
113+
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
114+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
115+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
116+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
117+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
118+
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
119+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
120+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
121+
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
122+
floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
98123
FloatArray.fromArray(ropeFreqs.first()),
99124
FloatArray.fromArray(ropeFreqs.second()),
100-
ModelLoader.loadTensorAsHalfFloatArray(outputWeight),
101-
outputWeight.ggmlType());
125+
loadTensorAsHalfFloatArray(outputWeight),
126+
outputWeight.ggmlType()
127+
);
128+
}
129+
130+
private Q8_0Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
131+
return new Q8_0Weights(
132+
loadTensorAsFloatArray(tokenEmbeddings),
133+
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
134+
loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
135+
loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
136+
loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
137+
loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
138+
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
139+
loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
140+
loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
141+
loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
142+
floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
143+
FloatArray.fromArray(ropeFreqs.first()),
144+
FloatArray.fromArray(ropeFreqs.second()),
145+
loadQ8_0QuantizedTensor(outputWeight),
146+
outputWeight.ggmlType()
147+
);
102148
}
103149
}

0 commit comments

Comments
 (0)