diff --git a/src/main/java/org/beehive/gpullama3/LlamaApp.java b/src/main/java/org/beehive/gpullama3/LlamaApp.java
index 7da9b878..822a082c 100644
--- a/src/main/java/org/beehive/gpullama3/LlamaApp.java
+++ b/src/main/java/org/beehive/gpullama3/LlamaApp.java
@@ -1,10 +1,8 @@
package org.beehive.gpullama3;
-import org.beehive.gpullama3.aot.AOT;
import org.beehive.gpullama3.auxiliary.LastRunMetrics;
import org.beehive.gpullama3.inference.sampler.Sampler;
import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.loader.ModelLoader;
import java.io.IOException;
diff --git a/src/main/java/org/beehive/gpullama3/aot/AOT.java b/src/main/java/org/beehive/gpullama3/aot/AOT.java
deleted file mode 100644
index 7fde18ca..00000000
--- a/src/main/java/org/beehive/gpullama3/aot/AOT.java
+++ /dev/null
@@ -1,85 +0,0 @@
-package org.beehive.gpullama3.aot;
-
-import org.beehive.gpullama3.auxiliary.Timer;
-import org.beehive.gpullama3.core.model.GGUF;
-import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
-import org.beehive.gpullama3.model.loader.LlamaModelLoader;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.Options;
-import org.beehive.gpullama3.model.format.LlamaChatFormat;
-import org.beehive.gpullama3.model.llama.Llama;
-import org.beehive.gpullama3.inference.weights.Weights;
-import org.beehive.gpullama3.tokenizer.impl.LlamaTokenizer;
-
-import java.io.IOException;
-import java.nio.channels.FileChannel;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.StandardOpenOption;
-import java.util.Map;
-import java.util.Objects;
-
-/**
- * Support for AOT preloading of GGUF metadata with GraalVM's Native Image.
- *
- *
- * To preload a model at build time, pass {@code -Dllama.PreloadGGUF=/path/to/model.gguf}
- * to the native-image builder command. At runtime, the preloaded model will be used
- * iff the specified and preloaded file names (base name) match.
- */
-public final class AOT {
- AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
-
- static LlamaModelLoader modelLoader;
-
- record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map tensorInfos) {
- }
-
- private static final PartialModel PRELOADED_GGUF = preLoadGGUF(System.getProperty("llama.PreloadGGUF"));
-
- private static PartialModel preLoadGGUF(String modelPath) {
- if (modelPath == null || modelPath.isEmpty()) {
- return null;
- }
- try {
- Path path = Path.of(modelPath);
- if (!Files.exists(path) || !Files.isRegularFile(path)) {
- throw new IllegalArgumentException("Cannot pre-load model: " + path);
- }
- GGUF gguf = GGUF.loadModel(path);
- try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
- modelLoader = new LlamaModelLoader(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false, false);
- return new PartialModel(path.getFileName().toString(), modelLoader.loadModel(), // TODO: needs proper handling for AOT
- gguf.getTensorDataOffset(), gguf.getTensorInfos());
- }
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
- }
-
- /**
- * Tries to reuse a compatible AOT preloaded model.
- * The file name (base name) must match with the preloaded file name.
- * No checksum/hash is checked for performance reasons.
- */
- public static Model tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
- AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
- if (preLoaded == null) {
- return null; // no pre-loaded model stored
- }
- String optionsModel = modelPath.getFileName().toString();
- String preLoadedModel = preLoaded.modelFileName();
- if (!Objects.equals(optionsModel, preLoadedModel)) {
- // Preloaded and specified model file names didn't match.
- return null;
- }
- Llama baseModel = preLoaded.model();
- try (var timer = Timer.log("Load tensors from pre-loaded model"); var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
- // Load only the tensors (mmap slices).
- Map tensorEntries = GGUF.loadTensors(fileChannel, preLoaded.tensorDataOffset(), preLoaded.tensorInfos());
- Weights weights = modelLoader.loadWeights(tensorEntries, baseModel.configuration());
- return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights, new LlamaChatFormat((LlamaTokenizer) baseModel.tokenizer()));
- }
- }
-}
-
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/FP16Weights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/FP16Weights.java
similarity index 95%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/FP16Weights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/FP16Weights.java
index 90f419bd..c9ad8419 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/FP16Weights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/FP16Weights.java
@@ -1,6 +1,7 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.fp16;
import org.beehive.gpullama3.core.model.GGMLType;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/LlamaTornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/LlamaTornadoWeights.java
similarity index 96%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/LlamaTornadoWeights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/LlamaTornadoWeights.java
index 00f601b8..02550e00 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/LlamaTornadoWeights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/LlamaTornadoWeights.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.fp16;
import org.beehive.gpullama3.core.model.GGMLType;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Phi3TornadoWeights.java
similarity index 97%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Phi3TornadoWeights.java
index 92410bf1..e6c12254 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Phi3TornadoWeights.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.fp16;
import org.beehive.gpullama3.core.model.GGMLType;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Qwen2TornadoWeights.java
similarity index 96%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Qwen2TornadoWeights.java
index 84617626..26c4d902 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Qwen2TornadoWeights.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.fp16;
import org.beehive.gpullama3.core.model.GGMLType;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Qwen3TornadoWeights.java
similarity index 97%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3TornadoWeights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Qwen3TornadoWeights.java
index 1236c121..06869323 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3TornadoWeights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Qwen3TornadoWeights.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.fp16;
import org.beehive.gpullama3.core.model.GGMLType;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeightsQ8_0.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Phi3TornadoWeightsQ8_0.java
similarity index 97%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeightsQ8_0.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Phi3TornadoWeightsQ8_0.java
index fbccd336..2a901acd 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeightsQ8_0.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Phi3TornadoWeightsQ8_0.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.q8_0;
import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Q8_0Weights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Q8_0Weights.java
similarity index 95%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Q8_0Weights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Q8_0Weights.java
index 04d4e11f..1de11ec4 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Q8_0Weights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Q8_0Weights.java
@@ -1,7 +1,8 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.q8_0;
import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
public class Q8_0Weights implements TornadoWeights {
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeightsQ8_0.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Qwen2TornadoWeightsQ8_0.java
similarity index 96%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeightsQ8_0.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Qwen2TornadoWeightsQ8_0.java
index 6cc29905..fb50b926 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeightsQ8_0.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Qwen2TornadoWeightsQ8_0.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.q8_0;
import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3Q8_0TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Qwen3Q8_0TornadoWeights.java
similarity index 96%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3Q8_0TornadoWeights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Qwen3Q8_0TornadoWeights.java
index c5dce240..aa6f0fe5 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3Q8_0TornadoWeights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Qwen3Q8_0TornadoWeights.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.q8_0;
import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java
new file mode 100644
index 00000000..3c3e4ea3
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java
@@ -0,0 +1,172 @@
+package org.beehive.gpullama3.model.loader;
+
+import org.beehive.gpullama3.core.model.GGUF;
+import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
+import org.beehive.gpullama3.core.types.Pair;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
+import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
+
+import java.io.IOException;
+import java.nio.channels.FileChannel;
+import java.util.Map;
+
+/**
+ * Abstract base class for model loaders using Template Method pattern. Provides common loading flow with extension points for model-specific logic.
+ *
+ * @param <M>
+ * The specific Model type to load
+ * @param <C>
+ * The specific Configuration type for the model
+ */
+public abstract class AbstractModelLoader<M extends Model, C extends Configuration> {
+
+ protected final FileChannel fileChannel;
+ protected final GGUF gguf;
+ protected final int contextLength;
+ protected final boolean loadWeights;
+ protected final boolean useTornadovm;
+
+ protected Vocabulary vocabulary;
+
+ protected AbstractModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
+ this.fileChannel = fileChannel;
+ this.gguf = gguf;
+ this.contextLength = contextLength;
+ this.loadWeights = loadWeights;
+ this.useTornadovm = useTornadovm;
+ }
+
+ /**
+ * Template method that defines the model loading workflow. Subclasses should not override this method.
+ *
+ * @return The loaded model instance
+ */
+ public final M loadModel() {
+ try {
+ Map<String, Object> metadata = gguf.getMetadata();
+
+ // Step 1: Load vocabulary
+ this.vocabulary = loadVocabulary(metadata);
+
+ // Step 2: Create tokenizer
+ Tokenizer tokenizer = createTokenizer(metadata, vocabulary);
+
+ // Step 3: Create configuration
+ C config = createConfiguration(metadata);
+
+ // Step 4: Load weights (if requested)
+ Weights weights = null;
+ if (loadWeights) {
+ Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
+ weights = loadWeights(tensorEntries, config);
+ }
+
+ // Step 5: Create and return model instance
+ return createModel(config, tokenizer, weights);
+
+ } catch (IOException e) {
+ throw new ModelLoadException("Failed to load model", e);
+ }
+ }
+
+ /**
+ * Load the vocabulary from GGUF metadata. Model-specific subclasses must implement this method.
+ *
+ * @param metadata
+ * The GGUF metadata map
+ * @return The loaded Vocabulary
+ */
+ protected abstract Vocabulary loadVocabulary(Map<String, Object> metadata);
+
+ /**
+ * Create a tokenizer instance for this model.
+ *
+ * @param metadata
+ * The GGUF metadata map
+ * @param vocabulary
+ * The loaded vocabulary
+ * @return The tokenizer instance
+ */
+ protected abstract Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary);
+
+ /**
+ * Create a configuration instance from GGUF metadata.
+ *
+ * @param metadata
+ * The GGUF metadata map
+ * @return The configuration instance
+ */
+ protected abstract C createConfiguration(Map<String, Object> metadata);
+
+ /**
+ * Load model weights from tensor entries. Default implementation handles common weight loading logic.
+ *
+ * @param tensorEntries
+ * Map of tensor names to tensor entries
+ * @param config
+ * The model configuration
+ * @return The loaded weights
+ */
+ public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, C config) {
+ // Precompute RoPE frequencies
+ Pair ropeFreqs = precomputeRopeFrequencies(config);
+
+ // Get token embeddings and output weights
+ GGMLTensorEntry tokenEmbeddings = getTokenEmbeddings(tensorEntries);
+ GGMLTensorEntry outputWeight = getOutputWeight(tensorEntries, tokenEmbeddings);
+
+ // Delegate to specific implementation
+ if (useTornadovm) {
+ return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ } else {
+ return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ }
+ }
+
+ /**
+ * Create the final model instance.
+ *
+ * @param config
+ * The model configuration
+ * @param tokenizer
+ * The tokenizer
+ * @param weights
+ * The loaded weights
+ * @return The model instance
+ */
+ protected abstract M createModel(C config, Tokenizer tokenizer, Weights weights);
+
+ /**
+ * Precompute RoPE frequencies for this model. Subclasses provide the model-specific RoPE configuration.
+ */
+ protected abstract Pair<float[], float[]> precomputeRopeFrequencies(C config);
+
+ /**
+ * Get token embeddings tensor entry. Default implementation can be overridden for different tensor naming.
+ */
+ protected GGMLTensorEntry getTokenEmbeddings(Map<String, GGMLTensorEntry> tensorEntries) {
+ return tensorEntries.get("token_embd.weight");
+ }
+
+ /**
+ * Get output weight tensor entry. Default implementation falls back to token embeddings if output.weight not found.
+ */
+ protected GGMLTensorEntry getOutputWeight(Map<String, GGMLTensorEntry> tensorEntries, GGMLTensorEntry tokenEmbeddings) {
+ return tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
+ }
+
+ /**
+ * Create standard (CPU) weights.
+ */
+ protected abstract Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight);
+
+ /**
+ * Create TornadoVM (GPU) weights.
+ */
+ protected abstract Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight);
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java
index 79f35c92..1b55184f 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/LlamaModelLoader.java
@@ -1,60 +1,149 @@
package org.beehive.gpullama3.model.loader;
-import org.beehive.gpullama3.auxiliary.Timer;
+import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.GGUF;
+import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
+import org.beehive.gpullama3.core.types.Pair;
+import org.beehive.gpullama3.inference.operation.RoPE;
import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.inference.weights.standard.LlamaStandardWeights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.LlamaTornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.model.llama.Llama;
import org.beehive.gpullama3.model.llama.LlamaConfiguration;
import org.beehive.gpullama3.tokenizer.impl.LlamaTokenizer;
import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import java.io.IOException;
import java.nio.channels.FileChannel;
import java.util.Map;
-public class LlamaModelLoader extends ModelLoader {
+import static org.beehive.gpullama3.model.loader.ModelLoader.*;
- public LlamaModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadoVM) {
- super(fileChannel, gguf, contextLength, loadWeights, useTornadoVM);
+public class LlamaModelLoader extends AbstractModelLoader<Llama, LlamaConfiguration> {
+
+ public LlamaModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
+ super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
+ }
+
+ @Override
+ protected Vocabulary loadVocabulary(Map<String, Object> metadata) {
+ return Vocabulary.loadLlamaVocabulary(metadata);
+ }
+
+ @Override
+ protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+ return new LlamaTokenizer(metadata, vocabulary);
+ }
+
+ @Override
+ protected LlamaConfiguration createConfiguration(Map<String, Object> metadata) {
+ int vocabSize = metadata.containsKey("llama.vocab_size") ? (int) metadata.get("llama.vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length");
+
+ return new LlamaConfiguration(
+ (int) metadata.get("llama.embedding_length"),
+ (int) metadata.get("llama.feed_forward_length"),
+ (int) metadata.get("llama.block_count"),
+ (int) metadata.get("llama.attention.head_count"),
+ metadata.containsKey("llama.attention.head_count_kv") ?
+ (int) metadata.get("llama.attention.head_count_kv")
+ : (int) metadata.get("llama.attention.head_count"),
+ vocabSize,
+ (int) metadata.get("llama.context_length"),
+ (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
+ (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)).withContextLength(contextLength);
+ }
+
+ @Override
+ protected Pair<float[], float[]> precomputeRopeFrequencies(LlamaConfiguration config) {
+ return RoPE.precomputeFreqsCis(config.contextLength(), config.dim() / config.numberOfHeads(), config.ropeTheta(), false, 1.0f, 1.0f, 1.0f, config.contextLength()
+ );
+ }
+
+ @Override
+ protected Llama createModel(LlamaConfiguration config, Tokenizer tokenizer, Weights weights) {
+ return new Llama(config, tokenizer, weights, ChatFormat.create(tokenizer, null));
}
- // @formatter:off
@Override
- public Llama loadModel() {
- try {
- Map metadata = gguf.getMetadata();
-
- Vocabulary vocabulary = Vocabulary.loadLlamaVocabulary(metadata);
- Tokenizer tokenizer = new LlamaTokenizer(metadata, vocabulary);
-
- LlamaConfiguration config = new LlamaConfiguration(
- (int) metadata.get("llama.embedding_length"),
- (int) metadata.get("llama.feed_forward_length"),
- (int) metadata.get("llama.block_count"),
- (int) metadata.get("llama.attention.head_count"),
-
- metadata.containsKey("llama.attention.head_count_kv") ?
- (int) metadata.get("llama.attention.head_count_kv") :
- (int) metadata.get("llama.attention.head_count"),
-
- vocabulary.size(),
- (int) metadata.get("llama.context_length"),
- (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
- (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
- ).withContextLength(contextLength);
-
- Weights weights = null;
- if (loadWeights) {
- Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
- weights = loadWeights(tensorEntries, config);
- }
- return new Llama(config, tokenizer, weights, ChatFormat.create(tokenizer, null));
- } catch (IOException e) {
- throw new RuntimeException(e);
+ protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+
+ return new LlamaStandardWeights(
+ loadQuantized(tokenEmbeddings),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+ loadQuantized(tensorEntries.get("output_norm.weight")),
+ new ArrayFloatTensor(ropeFreqs.first()),
+ new ArrayFloatTensor(ropeFreqs.second()),
+ loadQuantized(outputWeight),
+ outputWeight.ggmlType());
+ }
+
+ @Override
+ protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+ if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+ System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
}
+
+ GGMLType ggmlType = outputWeight.ggmlType();
+ return switch(ggmlType) {
+ case F16 -> createTornadoVMWeightsF16(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ case Q8_0 -> createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ default -> throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
+ };
+ }
+
+ private Weights createTornadoVMWeightsF16(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+ return new LlamaTornadoWeights(
+ loadTensorAsFloatArray(tokenEmbeddings),
+ loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+ loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+ loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+ loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+ loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+ loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+ loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+ loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+ loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+ floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
+ FloatArray.fromArray(ropeFreqs.first()),
+ FloatArray.fromArray(ropeFreqs.second()),
+ loadTensorAsHalfFloatArray(outputWeight),
+ outputWeight.ggmlType()
+ );
+ }
+
+ private Q8_0Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries, LlamaConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
+ return new Q8_0Weights(
+ loadTensorAsFloatArray(tokenEmbeddings),
+ loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+ loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+ floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
+ FloatArray.fromArray(ropeFreqs.first()),
+ FloatArray.fromArray(ropeFreqs.second()),
+ loadQ8_0QuantizedTensor(outputWeight),
+ outputWeight.ggmlType()
+ );
}
- // @formatter:on
}
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java
index efe64234..189db23f 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java
@@ -1,66 +1,152 @@
package org.beehive.gpullama3.model.loader;
-import org.beehive.gpullama3.auxiliary.Timer;
+import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.GGUF;
+import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
+import org.beehive.gpullama3.core.types.Pair;
+import org.beehive.gpullama3.inference.operation.RoPE;
import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.inference.weights.standard.LlamaStandardWeights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.LlamaTornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights;
import org.beehive.gpullama3.model.format.ChatFormat;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
import org.beehive.gpullama3.model.mistral.Mistral;
import org.beehive.gpullama3.model.mistral.MistralConfiguration;
import org.beehive.gpullama3.tokenizer.impl.MistralTokenizer;
import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import java.io.IOException;
import java.nio.channels.FileChannel;
import java.util.Map;
-public class MistralModelLoader extends ModelLoader {
+import static org.beehive.gpullama3.model.loader.ModelLoader.*;
+import static org.beehive.gpullama3.model.loader.ModelLoader.floatBufferToFloatArray;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayAsFloatArrayFromBuffer;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayAsQ8_0QuantizedTensor;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadQ8_0QuantizedTensor;
+
+public class MistralModelLoader extends AbstractModelLoader<Mistral, MistralConfiguration> {
public MistralModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
}
- // @formatter:off
@Override
- public Mistral loadModel() {
- try {
- Map metadata = gguf.getMetadata();
-
- Vocabulary vocabulary = Vocabulary.loadMistralVocabulary(metadata);
- Tokenizer tokenizer = new MistralTokenizer(metadata, vocabulary);
-
- int modelContextLength = (int) metadata.get("llama.context_length");
- if (contextLength < 0 || modelContextLength < contextLength) {
- contextLength = modelContextLength;
- }
-
- MistralConfiguration config = new MistralConfiguration(
- (int) metadata.get("llama.embedding_length"),
- (int) metadata.get("llama.feed_forward_length"),
- (int) metadata.get("llama.block_count"),
- (int) metadata.get("llama.attention.head_count"),
-
- metadata.containsKey("llama.attention.head_count_kv")
- ? (int) metadata.get("llama.attention.head_count_kv")
- : (int) metadata.get("llama.attention.head_count"),
-
- vocabulary.size(),
- contextLength,
- false,
- (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
- (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
- );
-
- Weights weights = null;
- if (loadWeights) {
- Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
- weights = loadWeights(tensorEntries, config);
- }
- return new Mistral(config, tokenizer, weights, ChatFormat.create(tokenizer, null));
- } catch (IOException e) {
- throw new RuntimeException(e);
+ protected Vocabulary loadVocabulary(Map metadata) {
+ return Vocabulary.loadMistralVocabulary(metadata);
+ }
+
+ @Override
+ protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) {
+ return new MistralTokenizer(metadata, vocabulary);
+ }
+
+ @Override
+ protected MistralConfiguration createConfiguration(Map metadata) {
+ int modelContextLength = (int) metadata.get("llama.context_length");
+ int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength;
+
+ // Get vocabulary size from metadata
+ int vocabSize = metadata.containsKey("llama.vocab_size") ? (int) metadata.get("llama.vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length");
+
+ return new MistralConfiguration((int) metadata.get("llama.embedding_length"), (int) metadata.get("llama.feed_forward_length"), (int) metadata.get("llama.block_count"),
+ (int) metadata.get("llama.attention.head_count"),
+
+ metadata.containsKey("llama.attention.head_count_kv") ? (int) metadata.get("llama.attention.head_count_kv") : (int) metadata.get("llama.attention.head_count"),
+
+ vocabSize, finalContextLength, false, (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
+ (float) metadata.getOrDefault("llama.rope.freq_base", 10000f));
+ }
+
+ @Override
+ protected Pair precomputeRopeFrequencies(MistralConfiguration config) {
+ return RoPE.precomputeFreqsCis(config.contextLength(), config.dim() / config.numberOfHeads(), config.ropeTheta(), false, 1.0f, 1.0f, 1.0f, config.contextLength()
+ );
+ }
+
+ @Override
+ protected Mistral createModel(MistralConfiguration config, Tokenizer tokenizer, Weights weights) {
+ return new Mistral(config, tokenizer, weights, ChatFormat.create(tokenizer, null));
+ }
+
+ @Override
+ protected Weights createStandardWeights(Map tensorEntries, MistralConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+
+ return new LlamaStandardWeights(
+ loadQuantized(tokenEmbeddings),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+ loadQuantized(tensorEntries.get("output_norm.weight")),
+ new ArrayFloatTensor(ropeFreqs.first()),
+ new ArrayFloatTensor(ropeFreqs.second()),
+ loadQuantized(outputWeight),
+ outputWeight.ggmlType());
+ }
+
+ @Override
+ protected Weights createTornadoVMWeights(Map tensorEntries, MistralConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+ if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+ System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
}
+
+ GGMLType ggmlType = outputWeight.ggmlType();
+ return switch(ggmlType) {
+ case F16 -> createTornadoVMWeightsF16(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ case Q8_0 -> createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ default -> throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
+ };
+ }
+
+ private Weights createTornadoVMWeightsF16(Map tensorEntries, MistralConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+
+ return new LlamaTornadoWeights(ModelLoader.loadTensorAsFloatArray(tokenEmbeddings),
+ ModelLoader.loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+ ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+ ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+ ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+ ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+ ModelLoader.loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+ ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+ ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+ ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+ ModelLoader.floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
+ FloatArray.fromArray(ropeFreqs.first()),
+ FloatArray.fromArray(ropeFreqs.second()),
+ ModelLoader.loadTensorAsHalfFloatArray(outputWeight),
+ outputWeight.ggmlType());
+ }
+
+ private Q8_0Weights createTornadoVMWeightsQ8_0(Map tensorEntries, MistralConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
+ return new Q8_0Weights(
+ loadTensorAsFloatArray(tokenEmbeddings),
+ loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+ loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+ floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
+ FloatArray.fromArray(ropeFreqs.first()),
+ FloatArray.fromArray(ropeFreqs.second()),
+ loadQ8_0QuantizedTensor(outputWeight),
+ outputWeight.ggmlType()
+ );
}
- // @formatter:on
}
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoadException.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoadException.java
new file mode 100644
index 00000000..f09ec56b
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoadException.java
@@ -0,0 +1,15 @@
+package org.beehive.gpullama3.model.loader;
+
+/**
+ * Exception thrown when model loading fails.
+ */
+public class ModelLoadException extends RuntimeException {
+
+ public ModelLoadException(String message) {
+ super(message);
+ }
+
+ public ModelLoadException(String message, Throwable cause) {
+ super(message, cause);
+ }
+}
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java
index 7d0b8dff..6b3c88eb 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java
@@ -1,32 +1,13 @@
package org.beehive.gpullama3.model.loader;
import org.beehive.gpullama3.Options;
-import org.beehive.gpullama3.aot.AOT;
import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.GGUF;
-import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
-import org.beehive.gpullama3.core.model.tensor.F16FloatTensor;
-import org.beehive.gpullama3.core.model.tensor.F32FloatTensor;
-import org.beehive.gpullama3.core.model.tensor.FloatTensor;
-import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
-import org.beehive.gpullama3.core.model.tensor.Q4_0FloatTensor;
-import org.beehive.gpullama3.core.model.tensor.Q8_0FloatTensor;
-import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
-import org.beehive.gpullama3.core.types.Pair;
-import org.beehive.gpullama3.inference.operation.RoPE;
-import org.beehive.gpullama3.inference.weights.Weights;
-import org.beehive.gpullama3.inference.weights.standard.LlamaStandardWeights;
-import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Q8_0Weights;
-import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.core.model.tensor.*;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.ModelType;
-import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import uk.ac.manchester.tornado.api.types.HalfFloat;
-import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
-import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
-import uk.ac.manchester.tornado.api.types.arrays.Int8Array;
+import uk.ac.manchester.tornado.api.types.arrays.*;
import java.io.IOException;
import java.lang.foreign.MemorySegment;
@@ -39,9 +20,10 @@
import java.util.Map;
import java.util.function.IntFunction;
-public abstract class ModelLoader {
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayAsHalfFloatArray;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadTensorAsFloatArray;
- public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
+public abstract class ModelLoader {
protected FileChannel fileChannel;
protected GGUF gguf;
@@ -99,13 +81,6 @@ private static ModelType detectModelType(Map metadata) {
* if AOT loading is enabled but the preloaded model is unavailable
*/
public static Model loadModel(Options options) throws IOException {
- if (USE_AOT) {
- Model model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
- if (model == null) {
- throw new IllegalStateException("Failed to load precompiled AOT model.");
- }
- return model;
- }
return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true, options.useTornadovm());
}
@@ -119,6 +94,10 @@ public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeig
return modelType.loadModel(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
}
+ /**
+ * Dispatcher method for loading a standard (non-tornado) tensor based on type.
+ * Used in CPU-path.
+ */
public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
GGMLType ggmlType = entry.ggmlType();
return switch (ggmlType) {
@@ -130,6 +109,55 @@ public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
};
}
+ /**
+ * Dispatcher method for loading a standard tensor array based on type.
+ * Used in CPU-path.
+ */
+ public static FloatTensor[] loadArrayOfQuantized(int size, IntFunction getTensorEntry) {
+ FloatTensor[] array = new FloatTensor[size];
+ for (int i = 0; i < size; i++) {
+ array[i] = loadQuantized(getTensorEntry.apply(i));
+ }
+ return array;
+ }
+
+ /**
+ * [WIP]
+ * Dispatcher method for loading a TornadoVM tensor based on type.
+ * Used in GPU-path.
+ *
+ * TODO: fix this to follow loadQuantized logic
+ */
+ public static FloatTensor loadTornadoTensor(GGMLTensorEntry entry) {
+ GGMLType ggmlType = entry.ggmlType();
+ int size = FloatTensor.numberOfElements(entry.shape());
+ return switch (ggmlType) {
+// case F32 -> new F32QuantizedTensor(size, entry.memorySegment());
+ case Q8_0 -> loadQ8_0QuantizedTensor(entry);
+// case Q4_0 -> throw new UnsupportedOperationException("Not yet implemented");
+// //FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()
+// case F16 -> new F16QuantizedTensor(size, entry.memorySegment());
+// /*{
+// HalfFloatArray array = new HalfFloatArray();
+// array.getSegment().copyFrom(entry.memorySegment());
+// // or array.getSegmentWithHeader() ?
+// }*/
+ default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
+ };
+ }
+
+ /**
+ * Dispatcher method for loading a TornadoVM tensor array based on type.
+ * Used in GPU-path.
+ */
+ public static FloatTensor[] loadTornadoTensorArray(int size, IntFunction getTensorEntry) {
+ FloatTensor[] array = new FloatTensor[size];
+ for (int i = 0; i < size; i++) {
+ array[i] = loadTornadoTensor(getTensorEntry.apply(i));
+ }
+ return array;
+ }
+
public static FloatArray[] loadArrayAsFloatArray(int size, IntFunction getTensorEntry) {
FloatArray[] array = new FloatArray[size];
for (int i = 0; i < size; i++) {
@@ -216,6 +244,8 @@ public static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry) {
}
}
+ // TODO: rename to loadQ8_0Tensor
+ // move to a utils class
public static Q8_0QuantizedTensor loadQ8_0QuantizedTensor(GGMLTensorEntry entry) {
if (entry.ggmlType() != GGMLType.Q8_0) {
throw new IllegalArgumentException("Expected Q8_0 tensor, got: " + entry.ggmlType() + " for tensor: " + entry.name());
@@ -240,6 +270,7 @@ public static Q8_0QuantizedTensor loadQ8_0QuantizedTensor(GGMLTensorEntry entry)
ValueLayout.OfByte byteLayout = ValueLayout.JAVA_BYTE;
for (int block = 0; block < numBlocks; block++) {
+ // TODO: use GGML type method for the 34L size
long blockOffset = block * 34L; // 34 bytes per block
// read fp16 scale (first 2 bytes of block)
@@ -247,6 +278,7 @@ public static Q8_0QuantizedTensor loadQ8_0QuantizedTensor(GGMLTensorEntry entry)
scales.set(block, new HalfFloat(scaleRaw));
// read 32 int8 quantized values (remaining bytes of block)
+ // TODO: use GGML type method for the 32 size
for (int i = 0; i < 32; i++) {
byte quantValue = q8Segment.get(byteLayout, blockOffset + 2 + i);
quants.set(block * 32 + i, quantValue);
@@ -256,14 +288,6 @@ public static Q8_0QuantizedTensor loadQ8_0QuantizedTensor(GGMLTensorEntry entry)
return new Q8_0QuantizedTensor(size, scales, quants, q8Segment);
}
- public static FloatTensor[] loadArrayOfQuantized(int size, IntFunction getTensorEntry) {
- FloatTensor[] array = new FloatTensor[size];
- for (int i = 0; i < size; i++) {
- array[i] = loadQuantized(getTensorEntry.apply(i));
- }
- return array;
- }
-
public static FloatBuffer[] loadArrayOfFloatBuffer(int size, IntFunction getTensorEntry) {
FloatBuffer[] array = new FloatBuffer[size];
for (int i = 0; i < size; i++) {
@@ -282,104 +306,6 @@ public static FloatBuffer toFloatBuffer(GGMLTensorEntry tensorEntry) {
public abstract Model loadModel();
- //@formatter:off
- public Weights loadWeights(Map tensorEntries, Configuration config) {
- boolean ropeScaling = tensorEntries.containsKey("rope_freqs");
- RopeConfig ropeConfig = new RopeConfig(8.0f, // scaleFactor
- 1.0f, // loFreqFactor
- 3.0f, // hiFreqFactor
- 8192 // oldContextLength
- );
-
- Pair ropeFreqs = RoPE.precomputeFreqsCis(
- config.contextLength(), // Maximum sequence length the model can process
- config.headSize(), // Dimension of each attention head
- config.ropeTheta(), // Base frequency parameter (typically 10000.0)
- ropeScaling, // Whether to apply frequency scaling (determined by model type)
- ropeConfig.scaleFactor, // Scale factor for extending context length (NTK-aware scaling)
- ropeConfig.loFreqFactor, // Low frequency scaling factor for better long-range dependencies
- ropeConfig.hiFreqFactor, // High frequency scaling factor for preserving local precision
- ropeConfig.oldContextLength // Original context length the model was trained with
- );
-
- GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
- GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
-
- if (useTornadovm) {
- if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
- System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + ")");
- }
-
- if (outputWeight.ggmlType() == GGMLType.Q8_0) {
- return createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- } else {
- return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- }
- } else {
- return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- }
- }
-
- public Weights createTornadoVMWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
- return new LlamaTornadoWeights(
- // Load directly to TornadoVM format
- loadTensorAsFloatArray(tokenEmbeddings), loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
- FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), loadTensorAsHalfFloatArray(outputWeight), outputWeight.ggmlType()) {
- };
- }
-
- private Q8_0Weights createTornadoVMWeightsQ8_0(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
- return new Q8_0Weights(
- loadTensorAsFloatArray(tokenEmbeddings),
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
- loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
- loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
- loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
- loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
- loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
- loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
- loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
- loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
- floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
- FloatArray.fromArray(ropeFreqs.first()),
- FloatArray.fromArray(ropeFreqs.second()),
- loadQ8_0QuantizedTensor(outputWeight),
- outputWeight.ggmlType()
- );
- }
-
- /**
- * Creates weights in standard format only
- */
- public Weights createStandardWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
- return new LlamaStandardWeights(
- loadQuantized(tokenEmbeddings),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
- loadQuantized(tensorEntries.get("output_norm.weight")),
- new ArrayFloatTensor(ropeFreqs.first()),
- new ArrayFloatTensor(ropeFreqs.second()),
- loadQuantized(outputWeight),
- outputWeight.ggmlType());
- }
-
// Helper class to encapsulate RoPE configuration parameters
private static class RopeConfig {
final float scaleFactor;
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java
index 14b0dab7..9f118dfa 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java
@@ -1,18 +1,17 @@
package org.beehive.gpullama3.model.loader;
-import org.beehive.gpullama3.LlamaApp;
-import org.beehive.gpullama3.Options;
-import org.beehive.gpullama3.auxiliary.Timer;
import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.GGUF;
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
+import org.beehive.gpullama3.core.model.tensor.FloatTensor;
import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
import org.beehive.gpullama3.core.types.Pair;
import org.beehive.gpullama3.inference.operation.RoPE;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeightsQ8_0;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.Phi3TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Phi3TornadoWeightsQ8_0;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.model.phi3.Phi3;
@@ -22,164 +21,185 @@
import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
-import java.io.IOException;
import java.nio.channels.FileChannel;
import java.util.Map;
-public class Phi3ModelLoader extends ModelLoader {
+import static org.beehive.gpullama3.model.loader.ModelLoader.*;
+import static org.beehive.gpullama3.model.loader.ModelLoader.floatBufferToFloatArray;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayAsFloatArrayFromBuffer;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayAsHalfFloatArray;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayAsQ8_0QuantizedTensor;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadQ8_0QuantizedTensor;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadTensorAsFloatArray;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadTensorAsHalfFloatArray;
+
+
+
+public class Phi3ModelLoader extends AbstractModelLoader {
+ private int modelContextLength;
+
public Phi3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
}
- // @formatter:off
@Override
- public Phi3 loadModel() {
- try {
- Map metadata = gguf.getMetadata();
- final String modelPrefix = "phi3.";
+ protected Vocabulary loadVocabulary(Map metadata) {
+ return Vocabulary.loadPhi3Vocabulary(metadata);
+ }
- Vocabulary vocabulary = Vocabulary.loadPhi3Vocabulary(metadata);
+ @Override
+ protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) {
+ if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
Tokenizer tokenizer = new Phi3Tokenizer(metadata, vocabulary);
-
- if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
- System.out.println("Tokenizer: " + tokenizer.getClass().getSimpleName());
- }
-
- int modelContextLength = (int) metadata.get(modelPrefix + "context_length");
- if (contextLength < 0 || modelContextLength < contextLength) {
- contextLength = modelContextLength;
- }
-
- Phi3Configuration config = new Phi3Configuration(
- (int) metadata.get(modelPrefix + "embedding_length"), // dim
- (int) metadata.get(modelPrefix + "feed_forward_length"), // hidden_dim
- (int) metadata.get(modelPrefix + "block_count"), // n_layers
- (int) metadata.get(modelPrefix + "attention.head_count"), // n_heads
-
- metadata.containsKey(modelPrefix + "attention.head_count_kv")
- ? (int) metadata.get(modelPrefix + "attention.head_count_kv")
- : (int) metadata.get(modelPrefix + "attention.head_count"), // n_kv_heads
-
- vocabulary.size(), // vocab_size
- contextLength, // context_length (user-specified, not model)
- (float) metadata.getOrDefault(modelPrefix + "attention.layer_norm_rms_epsilon", 1e-5f), // rms_norm_eps
- (float) metadata.getOrDefault(modelPrefix + "rope.freq_base", 10000f) // rope_theta
- );
-
- Weights weights = null;
- if (loadWeights) {
- Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
- weights = loadWeights(tensorEntries, config, modelContextLength);
- }
-
- // Phi3 chat tokens
- ChatFormat.ChatTokens chatTokens = new ChatFormat.ChatTokens(
- "<|system|>", "<|end|>", "<|user|>", "<|end|>", "<|assistant|>"
- );
-
- return new Phi3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
- } catch (IOException e) {
- throw new RuntimeException(e);
+ System.out.println("Tokenizer: " + tokenizer.getClass().getSimpleName());
+ return tokenizer;
}
+ return new Phi3Tokenizer(metadata, vocabulary);
}
- // @formatter:on
- // @formatter:off
- private Weights loadWeights(Map tensorEntries, Configuration config, int modelContextLength) {
+ @Override
+ protected Phi3Configuration createConfiguration(Map metadata) {
+ final String modelPrefix = "phi3.";
+ modelContextLength = (int) metadata.get(modelPrefix + "context_length");
+ int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength;
+
+ int vocabSize = metadata.containsKey(modelPrefix + "vocab_size") ? (int) metadata.get(modelPrefix + "vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length");
+
+ return new Phi3Configuration((int) metadata.get(modelPrefix + "embedding_length"), // dim
+ (int) metadata.get(modelPrefix + "feed_forward_length"), // hidden_dim
+ (int) metadata.get(modelPrefix + "block_count"), // n_layers
+ (int) metadata.get(modelPrefix + "attention.head_count"), // n_heads
+
+ metadata.containsKey(modelPrefix + "attention.head_count_kv") ? (int) metadata.get(modelPrefix + "attention.head_count_kv") : (int) metadata.get(modelPrefix + "attention.head_count"), // n_kv_heads
+
+ vocabSize, // vocab_size
+ finalContextLength, // context_length (user-specified, not model)
+ (float) metadata.getOrDefault(modelPrefix + "attention.layer_norm_rms_epsilon", 1e-5f), // rms_norm_eps
+ (float) metadata.getOrDefault(modelPrefix + "rope.freq_base", 10000f) // rope_theta
+ );
+ }
+
+ @Override
+ protected Pair precomputeRopeFrequencies(Phi3Configuration config) {
// Calculate head size from dim and numberOfHeads
int headSize = config.dim() / config.numberOfHeads();
- Pair ropeFreqs = RoPE.precomputeFreqsCis(
- modelContextLength, // Use model context length for RoPE precomputation
+ return RoPE.precomputeFreqsCis(modelContextLength, // Use model context length for RoPE precomputation
headSize, // Calculated head size
- config.ropeTheta(),
- false, // Phi3 uses standard RoPE, not neox-style based on reference
+ config.ropeTheta(), false, // Phi3 uses standard RoPE, not neox-style based on reference
8, 1, 3, 8192 // Additional RoPE parameters from reference
);
+ }
- GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
- GGMLTensorEntry outputWeight = tensorEntries.get("output.weight"); // Phi3 always has separate output weight
-
- if (useTornadovm) {
- if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
- System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + ")");
- }
- if (outputWeight.ggmlType() == GGMLType.Q8_0) {
- return createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- } else {
- return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- }
- } else {
- return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- }
+ @Override
+ protected Phi3 createModel(Phi3Configuration config, Tokenizer tokenizer, Weights weights) {
+ // Phi3 chat tokens
+ ChatFormat.ChatTokens chatTokens = new ChatFormat.ChatTokens("<|system|>", "<|end|>", "<|user|>", "<|end|>", "<|assistant|>");
+
+ return new Phi3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
+ }
+
+ @Override
+ protected Weights createStandardWeights(Map tensorEntries, Phi3Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+ float[] ropeFreqsReal = ropeFreqs.first();
+ float[] ropeFreqsImag = ropeFreqs.second();
+
+ return new Phi3StandardWeights(
+ loadQuantized(tokenEmbeddings), // token_embedding_table
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight (as FloatTensor[])
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // wqkv (combined)
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // rms_ffn_weight (as FloatTensor[])
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (separate, not combined)
+ loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight (as FloatTensor)
+ new ArrayFloatTensor(ropeFreqsReal), // freq_cis_real
+ new ArrayFloatTensor(ropeFreqsImag), // freq_cis_imag
+ loadQuantized(outputWeight), // wcls
+ outputWeight.ggmlType() // weightType
+ );
}
- // @formatter:on
- // @formatter:off
- public Weights createTornadoVMWeightsQ8_0(Map tensorEntries, Configuration config,
- Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ @Override
+ protected Weights createTornadoVMWeights(Map tensorEntries, Phi3Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
GGMLTensorEntry outputWeight) {
- return new Phi3TornadoWeightsQ8_0(
+ if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+ System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + ")");
+ }
+
+ GGMLType ggmlType = outputWeight.ggmlType();
+ return switch(ggmlType) {
+ case F16 -> createTornadoVMWeightsF16(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ case Q8_0 -> createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ default -> throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
+ };
+ }
+
+ private Weights createTornadoVMWeightsF16(Map tensorEntries, Phi3Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+ return new Phi3TornadoWeights(
loadTensorAsFloatArray(tokenEmbeddings),
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
- loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // Combined QKV
- loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
+ loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // Combined QKV
+ loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
- loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
- loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (not combined in reference)
+ loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
+ loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (not combined in reference)
floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
FloatArray.fromArray(ropeFreqs.first()),
FloatArray.fromArray(ropeFreqs.second()),
- loadQ8_0QuantizedTensor(outputWeight),
+ loadTensorAsHalfFloatArray(outputWeight),
outputWeight.ggmlType()
);
}
- public Weights createTornadoVMWeights(Map tensorEntries, Configuration config,
- Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
- return new Phi3TornadoWeights(
+ public Q8_0Weights createTornadoVMWeightsQ8_0(Map tensorEntries, Configuration config,
+ Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+ return new Phi3TornadoWeightsQ8_0(
loadTensorAsFloatArray(tokenEmbeddings),
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // Combined QKV
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // Combined QKV
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
- loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (not combined in reference)
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
+ loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (not combined in reference)
floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
FloatArray.fromArray(ropeFreqs.first()),
FloatArray.fromArray(ropeFreqs.second()),
- loadTensorAsHalfFloatArray(outputWeight),
+ loadQ8_0QuantizedTensor(outputWeight),
outputWeight.ggmlType()
);
}
- // @formatter:on
- // @formatter:off
- @Override
- public Weights createStandardWeights(Map tensorEntries,
- Configuration config,
- Pair ropeFreqs,
- GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
- float[] ropeFreqsReal = ropeFreqs.first();
- float[] ropeFreqsImag = ropeFreqs.second();
+ // Helper methods
+ private FloatTensor[] loadLayerWeights(Map tensorEntries, Phi3Configuration config, String layerName, String suffix) {
+ FloatTensor[] weights = new FloatTensor[config.numberOfLayers()];
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+ weights[i] = ModelLoader.loadQuantized(tensorEntries.get(key));
+ }
+ return weights;
+ }
- return new Phi3StandardWeights(
- loadQuantized(tokenEmbeddings), // token_embedding_table
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight (as FloatTensor[])
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // wqkv (combined)
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // rms_ffn_weight (as FloatTensor[])
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (separate, not combined)
- loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight (as FloatTensor)
- new ArrayFloatTensor(ropeFreqsReal), // freq_cis_real
- new ArrayFloatTensor(ropeFreqsImag), // freq_cis_imag
- loadQuantized(outputWeight), // wcls
- outputWeight.ggmlType() // weightType
- );
+ private FloatArray[] loadLayerWeightsAsFloatArraysFromBuffer(Map tensorEntries, Phi3Configuration config, String layerName, String suffix) {
+ FloatArray[] weights = new FloatArray[config.numberOfLayers()];
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+ weights[i] = ModelLoader.floatBufferToFloatArray(tensorEntries.get(key));
+ }
+ return weights;
+ }
+
+ private HalfFloatArray[] loadLayerWeightsAsHalfFloatArrays(Map tensorEntries, Phi3Configuration config, String layerName, String suffix) {
+ HalfFloatArray[] weights = new HalfFloatArray[config.numberOfLayers()];
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+ weights[i] = ModelLoader.loadTensorAsHalfFloatArray(tensorEntries.get(key));
+ }
+ return weights;
}
- // @formatter:on
}
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java
index fef3eb9d..bd31066e 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java
@@ -1,18 +1,18 @@
package org.beehive.gpullama3.model.loader;
-import org.beehive.gpullama3.Options;
import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.GGUF;
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
+import org.beehive.gpullama3.core.model.tensor.FloatTensor;
import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
import org.beehive.gpullama3.core.types.Pair;
import org.beehive.gpullama3.inference.operation.RoPE;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeightsQ8_0;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.Qwen2TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Qwen2TornadoWeightsQ8_0;
import org.beehive.gpullama3.model.Configuration;
-import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens;
import org.beehive.gpullama3.model.qwen2.Qwen2;
@@ -22,97 +22,76 @@
import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
-import java.io.IOException;
import java.nio.channels.FileChannel;
import java.util.Map;
+import static org.beehive.gpullama3.core.model.GGMLType.F16;
+import static org.beehive.gpullama3.model.loader.ModelLoader.*;
import static org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary.loadQwen3Vocabulary;
-public class Qwen2ModelLoader extends ModelLoader {
+public class Qwen2ModelLoader extends AbstractModelLoader {
public Qwen2ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
}
@Override
- public Model loadModel() {
- Map metadata = gguf.getMetadata();
- String basename = (String) metadata.get("general.basename");
-
- String modelName = "DeepSeek-R1-Distill-Qwen".equals(basename) ? "DeepSeek-R1-Distill-Qwen" : "Qwen2.5";
-
- try {
- // reuse method of Qwen3
- Vocabulary vocabulary = loadQwen3Vocabulary(metadata);
- boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
- Tokenizer tokenizer = new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
-
- int modelContextLength = (int) metadata.get("qwen2.context_length");
- if (contextLength < 0 || modelContextLength < contextLength) {
- contextLength = modelContextLength;
- }
-
- int numberOfKeyValueHeads = metadata.containsKey("qwen2.attention.head_count_kv") ? (int) metadata.get("qwen2.attention.head_count_kv") : (int) metadata.get("qwen2.attention.head_count");
- Qwen2Configuration config = new Qwen2Configuration((int) metadata.get("qwen2.embedding_length"), // dim
- (int) metadata.get("qwen2.feed_forward_length"), // hiddendim
- (int) metadata.get("qwen2.block_count"), // numberOfLayers
- (int) metadata.get("qwen2.attention.head_count"), // numberOfHeads
-
- numberOfKeyValueHeads, // numberOfKeyValueHeads
- numberOfKeyValueHeads, // numberOfHeadsKey
- numberOfKeyValueHeads, // numberOfHeadsValue
-
- vocabulary.size(), modelContextLength, contextLength, false, (float) metadata.get("qwen2.attention.layer_norm_rms_epsilon"), (float) metadata.get("qwen2.rope.freq_base"));
-
- Weights weights = null;
- if (loadWeights) {
- Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
- weights = loadWeights(tensorEntries, config);
- }
- // Qwen2.5-Coder uses <|endoftext|> as stop-token.
- ChatTokens chatTokens = isDeepSeekR1DistillQwen
- ? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "")
- : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
- return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
+ protected Vocabulary loadVocabulary(Map metadata) {
+ return loadQwen3Vocabulary(metadata);
+ }
+
+ @Override
+ protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) {
+ boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
+ return new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
}
- // @formatter:off
@Override
- public Weights loadWeights(Map tensorEntries, Configuration config) {
- Pair ropeFreqs = RoPE.precomputeFreqsCis(
- config.contextLengthModel(),
- config.headSize(),
- config.ropeTheta(),
+ protected Qwen2Configuration createConfiguration(Map metadata) {
+ int modelContextLength = (int) metadata.get("qwen2.context_length");
+ int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength;
+
+ int numberOfKeyValueHeads = metadata.containsKey("qwen2.attention.head_count_kv") ? (int) metadata.get("qwen2.attention.head_count_kv") : (int) metadata.get("qwen2.attention.head_count");
+ int vocabSize = vocabulary.size();
+
+ return new Qwen2Configuration(
+ (int) metadata.get("qwen2.embedding_length"), // dim
+ (int) metadata.get("qwen2.feed_forward_length"), // hiddendim
+ (int) metadata.get("qwen2.block_count"), // numberOfLayers
+ (int) metadata.get("qwen2.attention.head_count"), // numberOfHeads
+
+ numberOfKeyValueHeads, // numberOfKeyValueHeads
+ numberOfKeyValueHeads, // numberOfHeadsKey
+ numberOfKeyValueHeads, // numberOfHeadsValue
+
+ vocabSize,
+ modelContextLength,
+ finalContextLength,
false,
- 8,
- 1,
- 3,
- 8192
+ (float) metadata.get("qwen2.attention.layer_norm_rms_epsilon"),
+ (float) metadata.get("qwen2.rope.freq_base")
);
+ }
- GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
- GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
-
- if (useTornadovm) {
- if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
- System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + ")");
- }
- if (outputWeight.ggmlType() == GGMLType.Q8_0) {
- return createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- } else {
- return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- }
- } else {
- return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- }
+ @Override
+ protected Pair precomputeRopeFrequencies(Qwen2Configuration config) {
+ return RoPE.precomputeFreqsCis(config.contextLengthModel(), config.headSize(), config.ropeTheta(), false, 8, 1, 3, 8192);
}
@Override
- public Weights createStandardWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ protected Qwen2 createModel(Qwen2Configuration config, Tokenizer tokenizer, Weights weights) {
+ Map metadata = gguf.getMetadata();
+ boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
+ // Qwen2.5-Coder uses <|endoftext|> as stop-token.
+ ChatTokens chatTokens = isDeepSeekR1DistillQwen ? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "")
+ : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
+ return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
+ }
+
+ @Override
+ protected Weights createStandardWeights(Map tensorEntries, Qwen2Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
GGMLTensorEntry outputWeight) {
return new Qwen2StandardWeights(
loadQuantized(tokenEmbeddings),
@@ -120,11 +99,9 @@ public Weights createStandardWeights(Map tensorEntries,
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
-
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
-
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
@@ -134,12 +111,28 @@ public Weights createStandardWeights(Map tensorEntries,
new ArrayFloatTensor(ropeFreqs.first()),
new ArrayFloatTensor(ropeFreqs.second()),
loadQuantized(outputWeight),
- outputWeight.ggmlType());
+ outputWeight.ggmlType()
+ );
}
@Override
- public Weights createTornadoVMWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ protected Weights createTornadoVMWeights(Map tensorEntries, Qwen2Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
GGMLTensorEntry outputWeight) {
+ if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+ System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + ")");
+ }
+
+ GGMLType ggmlType = outputWeight.ggmlType();
+ return switch(ggmlType) {
+ case F16 -> createTornadoVMWeightsF16(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ case Q8_0 -> createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ default ->
+ throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
+ };
+ }
+
+ private Weights createTornadoVMWeightsF16(Map tensorEntries, Qwen2Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
return new Qwen2TornadoWeights(
loadTensorAsFloatArray(tokenEmbeddings),
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
@@ -150,7 +143,6 @@ public Weights createTornadoVMWeights(Map tensorEntries
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
-
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
@@ -164,8 +156,8 @@ public Weights createTornadoVMWeights(Map tensorEntries
);
}
- public Weights createTornadoVMWeightsQ8_0(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
+ public Q8_0Weights createTornadoVMWeightsQ8_0(Map tensorEntries, Qwen2Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
return new Qwen2TornadoWeightsQ8_0(
loadTensorAsFloatArray(tokenEmbeddings),
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
@@ -189,6 +181,32 @@ public Weights createTornadoVMWeightsQ8_0(Map tensorEnt
outputWeight.ggmlType()
);
}
- // @formatter:on
+ // Helper methods
+ private FloatTensor[] loadLayerWeights(Map tensorEntries, Qwen2Configuration config, String layerName, String suffix) {
+ FloatTensor[] weights = new FloatTensor[config.numberOfLayers()];
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+ weights[i] = ModelLoader.loadQuantized(tensorEntries.get(key));
+ }
+ return weights;
+ }
+
+ private FloatArray[] loadLayerWeightsAsFloatArraysFromBuffer(Map tensorEntries, Qwen2Configuration config, String layerName, String suffix) {
+ FloatArray[] weights = new FloatArray[config.numberOfLayers()];
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+ weights[i] = ModelLoader.floatBufferToFloatArray(tensorEntries.get(key));
+ }
+ return weights;
+ }
+
+ private HalfFloatArray[] loadLayerWeightsAsHalfFloatArrays(Map tensorEntries, Qwen2Configuration config, String layerName, String suffix) {
+ HalfFloatArray[] weights = new HalfFloatArray[config.numberOfLayers()];
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+ weights[i] = ModelLoader.loadTensorAsHalfFloatArray(tensorEntries.get(key));
+ }
+ return weights;
+ }
}
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java
index 682c7477..597292cd 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java
@@ -5,16 +5,18 @@
import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.GGUF;
import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
+import org.beehive.gpullama3.core.model.tensor.FloatTensor;
import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
import org.beehive.gpullama3.core.types.Pair;
import org.beehive.gpullama3.inference.operation.RoPE;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen3Q8_0TornadoWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
-import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Qwen3Q8_0TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.Qwen3TornadoWeights;
import org.beehive.gpullama3.model.format.ChatFormat;
import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
import org.beehive.gpullama3.model.qwen3.Qwen3;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer;
@@ -22,105 +24,120 @@
import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
-import java.io.IOException;
import java.nio.channels.FileChannel;
import java.util.Map;
+import static org.beehive.gpullama3.model.loader.ModelLoader.*;
import static org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary.loadQwen3Vocabulary;
-public class Qwen3ModelLoader extends ModelLoader {
+public class Qwen3ModelLoader extends AbstractModelLoader {
public Qwen3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
}
- // @formatter:off
@Override
- public Qwen3 loadModel() {
- try {
- Map metadata = gguf.getMetadata();
-
- Vocabulary vocabulary = loadQwen3Vocabulary(metadata);
- boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
- Tokenizer tokenizer = new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
-
- int modelContextLength = (int) metadata.get("qwen3.context_length");
- if (contextLength < 0 || modelContextLength < contextLength) {
- contextLength = modelContextLength;
- }
-
- Qwen3Configuration config = new Qwen3Configuration(
- (int) metadata.get("qwen3.embedding_length"),
- (int) metadata.get("qwen3.feed_forward_length"),
- (int) metadata.get("qwen3.block_count"),
- (int) metadata.get("qwen3.attention.head_count"),
-
- metadata.containsKey("qwen3.attention.head_count_kv")
- ? (int) metadata.get("qwen3.attention.head_count_kv")
- : (int) metadata.get("qwen3.attention.head_count"),
- (int) metadata.get("qwen3.attention.key_length"),
- (int) metadata.get("qwen3.attention.value_length"),
-
- vocabulary.size(),
- modelContextLength, contextLength,
- false,
- (float) metadata.get("qwen3.attention.layer_norm_rms_epsilon"),
- (float) metadata.get("qwen3.rope.freq_base")
- );
-
- Weights weights = null;
- if (loadWeights) {
- Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
- weights = loadWeights(tensorEntries, config);
- }
- // Qwen2.5-coder uses <|endoftext|> as stop-token.
- ChatTokens chatTokens = isDeepSeekR1DistillQwen ?
- new ChatTokens( "<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") :
- new ChatTokens( "<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
- return new Qwen3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
- } catch (IOException e) {
- throw new RuntimeException(e);
- }
+ protected Vocabulary loadVocabulary(Map metadata) {
+ return loadQwen3Vocabulary(metadata);
}
- // @formatter:on
- // @formatter:off
@Override
- public Weights loadWeights(Map tensorEntries, Configuration config) {
- Pair ropeFreqs = RoPE.precomputeFreqsCis(
- config.contextLengthModel(),
- config.numberOfHeadsKey(),
- config.ropeTheta(),
+ protected Tokenizer createTokenizer(Map metadata, Vocabulary vocabulary) {
+ boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
+ return new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
+ }
+
+ @Override
+ protected Qwen3Configuration createConfiguration(Map metadata) {
+ int modelContextLength = (int) metadata.get("qwen3.context_length");
+ int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength;
+
+ int vocabSize = vocabulary.size();
+
+ return new Qwen3Configuration(
+ (int) metadata.get("qwen3.embedding_length"),
+ (int) metadata.get("qwen3.feed_forward_length"),
+ (int) metadata.get("qwen3.block_count"),
+ (int) metadata.get("qwen3.attention.head_count"),
+
+ metadata.containsKey("qwen3.attention.head_count_kv") ?
+ (int) metadata.get("qwen3.attention.head_count_kv") :
+ (int) metadata.get("qwen3.attention.head_count"),
+ (int) metadata.get("qwen3.attention.key_length"),
+ (int) metadata.get("qwen3.attention.value_length"),
+
+ vocabSize,
+ modelContextLength,
+ finalContextLength,
false,
- 0,
- 0,
- 0,
- 0
+ (float) metadata.get("qwen3.attention.layer_norm_rms_epsilon"),
+ (float) metadata.get("qwen3.rope.freq_base")
);
+ }
- GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
- GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
-
- if (useTornadovm) {
- if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
- System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + ")");
- }
- if (outputWeight.ggmlType() == GGMLType.Q8_0) {
- return createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- } else {
- return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- }
- } else {
- return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
- }
+ @Override
+ protected Pair precomputeRopeFrequencies(Qwen3Configuration config) {
+ return RoPE.precomputeFreqsCis(config.contextLengthModel(), config.numberOfHeadsKey(), config.ropeTheta(), false, 0, 0, 0, 0);
+ }
+
+ @Override
+ protected Qwen3 createModel(Qwen3Configuration config, Tokenizer tokenizer, Weights weights) {
+ Map metadata = gguf.getMetadata();
+ boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
+ // Qwen2.5-coder uses <|endoftext|> as stop-token.
+ ChatTokens chatTokens = isDeepSeekR1DistillQwen ? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "")
+ : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
+ return new Qwen3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
}
- // @formatter:on
- // @formatter:off
@Override
- public Weights createTornadoVMWeights(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ protected Weights createStandardWeights(Map tensorEntries, Qwen3Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
GGMLTensorEntry outputWeight) {
+ float[] ropeFreqsReal = ropeFreqs.first();
+ float[] ropeFreqsImag = ropeFreqs.second();
+ return new Qwen3StandardWeights(
+ loadQuantized(tokenEmbeddings),
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), // wq
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), // wk
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), // wv
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // attnKNorm
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // attnQNorm
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), //rms_ffn_weight
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
+ loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
+ loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight
+ new ArrayFloatTensor(ropeFreqsReal),
+ new ArrayFloatTensor(ropeFreqsImag),
+ tensorEntries.containsKey("output.weight")
+ ? ModelLoader.loadQuantized(tensorEntries.get("output.weight"))
+ : loadQuantized(tokenEmbeddings), // weights are shared
+ null
+ );
+ }
+
+ @Override
+ protected Weights createTornadoVMWeights(Map tensorEntries, Qwen3Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
+ if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+ System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
+ }
+
+ GGMLType ggmlType = outputWeight.ggmlType();
+ return switch(ggmlType) {
+ case F16 -> createTornadoVMWeightsF16(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ case Q8_0 -> createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+ default -> throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
+ };
+
+ }
+
+ private Weights createTornadoVMWeightsF16(Map tensorEntries, Qwen3Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
return new Qwen3TornadoWeights(
loadTensorAsFloatArray(tokenEmbeddings),
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
@@ -142,8 +159,8 @@ public Weights createTornadoVMWeights(Map tensorEntries
);
}
- public Weights createTornadoVMWeightsQ8_0(Map tensorEntries, Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
+ private Q8_0Weights createTornadoVMWeightsQ8_0(Map tensorEntries, Qwen3Configuration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+ GGMLTensorEntry outputWeight) {
return new Qwen3Q8_0TornadoWeights(
loadTensorAsFloatArray(tokenEmbeddings),
loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
@@ -164,40 +181,32 @@ public Weights createTornadoVMWeightsQ8_0(Map tensorEnt
outputWeight.ggmlType()
);
}
- // @formatter:on
- // @formatter:off
- @Override
- public Weights createStandardWeights(Map tensorEntries,
- Configuration config,
- Pair ropeFreqs,
- GGMLTensorEntry tokenEmbeddings,
- GGMLTensorEntry outputWeight) {
- float[] ropeFreqsReal = ropeFreqs.first();
- float[] ropeFreqsImag = ropeFreqs.second();
- return new Qwen3StandardWeights(
- loadQuantized(tokenEmbeddings),
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), // wq
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), // wk
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), // wv
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
+ // Helper methods
+ private FloatTensor[] loadLayerWeights(Map tensorEntries, Qwen3Configuration config, String layerName, String suffix) {
+ FloatTensor[] weights = new FloatTensor[config.numberOfLayers()];
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+ weights[i] = ModelLoader.loadQuantized(tensorEntries.get(key));
+ }
+ return weights;
+ }
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // attnKNorm
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // attnQNorm
+ private FloatArray[] loadLayerWeightsAsFloatArraysFromBuffer(Map tensorEntries, Qwen3Configuration config, String layerName, String suffix) {
+ FloatArray[] weights = new FloatArray[config.numberOfLayers()];
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+ weights[i] = ModelLoader.floatBufferToFloatArray(tensorEntries.get(key));
+ }
+ return weights;
+ }
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), //rms_ffn_weight
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
- loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
- loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight
- new ArrayFloatTensor(ropeFreqsReal),
- new ArrayFloatTensor(ropeFreqsImag),
- tensorEntries.containsKey("output.weight")
- ? ModelLoader.loadQuantized(tensorEntries.get("output.weight"))
- : loadQuantized(tokenEmbeddings), // weights are shared
- null
- );
+ private HalfFloatArray[] loadLayerWeightsAsHalfFloatArrays(Map tensorEntries, Qwen3Configuration config, String layerName, String suffix) {
+ HalfFloatArray[] weights = new HalfFloatArray[config.numberOfLayers()];
+ for (int i = 0; i < config.numberOfLayers(); i++) {
+ String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+ weights[i] = ModelLoader.loadTensorAsHalfFloatArray(tensorEntries.get(key));
+ }
+ return weights;
}
- // @formatter:on
}
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java
index 6cfdb821..6c5b0238 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java
@@ -2,7 +2,7 @@
import org.beehive.gpullama3.auxiliary.Tuple2;
import org.beehive.gpullama3.inference.state.Phi3State;
-import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.Phi3TornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import uk.ac.manchester.tornado.api.GridScheduler;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlannerQ8_0.java b/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlannerQ8_0.java
index dbdd204a..0197a655 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlannerQ8_0.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlannerQ8_0.java
@@ -2,7 +2,7 @@
import org.beehive.gpullama3.auxiliary.Tuple2;
import org.beehive.gpullama3.inference.state.Phi3State;
-import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeightsQ8_0;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Phi3TornadoWeightsQ8_0;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.phi3.Phi3Configuration;
import uk.ac.manchester.tornado.api.GridScheduler;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2Q8_0TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2Q8_0TornadoVMLayerPlanner.java
similarity index 99%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2Q8_0TornadoVMLayerPlanner.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/Qwen2Q8_0TornadoVMLayerPlanner.java
index 4884e4af..1d109a04 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2Q8_0TornadoVMLayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2Q8_0TornadoVMLayerPlanner.java
@@ -1,7 +1,8 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.tornadovm;
import org.beehive.gpullama3.auxiliary.Tuple2;
import org.beehive.gpullama3.inference.state.Qwen2State;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Qwen2TornadoWeightsQ8_0;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import org.beehive.gpullama3.tornadovm.Qwen2Kernels;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java
index 1f9d547b..e3155afa 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java
@@ -2,7 +2,7 @@
import org.beehive.gpullama3.auxiliary.Tuple2;
import org.beehive.gpullama3.inference.state.Qwen2State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.Qwen2TornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
import uk.ac.manchester.tornado.api.GridScheduler;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3Q8_0TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3Q8_0TornadoVMLayerPlanner.java
index fd294965..4942ce2f 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3Q8_0TornadoVMLayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3Q8_0TornadoVMLayerPlanner.java
@@ -2,7 +2,7 @@
import org.beehive.gpullama3.auxiliary.Tuple2;
import org.beehive.gpullama3.inference.state.Qwen3State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen3Q8_0TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Qwen3Q8_0TornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
import uk.ac.manchester.tornado.api.GridScheduler;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java
index 57d08a90..e04e8eef 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java
@@ -2,7 +2,7 @@
import org.beehive.gpullama3.auxiliary.Tuple2;
import org.beehive.gpullama3.inference.state.Qwen3State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.Qwen3TornadoWeights;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
import uk.ac.manchester.tornado.api.GridScheduler;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java
index 4849b847..02ccf272 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java
@@ -1,7 +1,7 @@
package org.beehive.gpullama3.tornadovm;
import org.beehive.gpullama3.auxiliary.Tuple2;
-import org.beehive.gpullama3.inference.weights.tornado.FP16Weights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.FP16Weights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.inference.state.State;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
index 1e420b1a..8cae8eac 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
@@ -6,7 +6,6 @@
import org.beehive.gpullama3.inference.state.Qwen2State;
import org.beehive.gpullama3.inference.state.Qwen3State;
import org.beehive.gpullama3.inference.state.State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen2Q8_0TornadoVMLayerPlanner;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.ModelType;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMQ8_0LayerPlanner.java
index 347f3267..1173a694 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMQ8_0LayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMQ8_0LayerPlanner.java
@@ -2,7 +2,7 @@
import org.beehive.gpullama3.auxiliary.Tuple2;
import org.beehive.gpullama3.inference.state.State;
-import org.beehive.gpullama3.inference.weights.tornado.Q8_0Weights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import uk.ac.manchester.tornado.api.GridScheduler;