diff --git a/src/main/java/org/beehive/gpullama3/LlamaApp.java b/src/main/java/org/beehive/gpullama3/LlamaApp.java
index 7da9b878..822a082c 100644
--- a/src/main/java/org/beehive/gpullama3/LlamaApp.java
+++ b/src/main/java/org/beehive/gpullama3/LlamaApp.java
@@ -1,10 +1,8 @@
 package org.beehive.gpullama3;
 
-import org.beehive.gpullama3.aot.AOT;
 import org.beehive.gpullama3.auxiliary.LastRunMetrics;
 import org.beehive.gpullama3.inference.sampler.Sampler;
 import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.model.loader.ModelLoader;
 
 import java.io.IOException;
 
diff --git a/src/main/java/org/beehive/gpullama3/aot/AOT.java b/src/main/java/org/beehive/gpullama3/aot/AOT.java
deleted file mode 100644
index 7fde18ca..00000000
--- a/src/main/java/org/beehive/gpullama3/aot/AOT.java
+++ /dev/null
@@ -1,85 +0,0 @@
-package org.beehive.gpullama3.aot;
-
-import org.beehive.gpullama3.auxiliary.Timer;
-import org.beehive.gpullama3.core.model.GGUF;
-import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
-import org.beehive.gpullama3.model.loader.LlamaModelLoader;
-import org.beehive.gpullama3.model.Model;
-import org.beehive.gpullama3.Options;
-import org.beehive.gpullama3.model.format.LlamaChatFormat;
-import org.beehive.gpullama3.model.llama.Llama;
-import org.beehive.gpullama3.inference.weights.Weights;
-import org.beehive.gpullama3.tokenizer.impl.LlamaTokenizer;
-
-import java.io.IOException;
-import java.nio.channels.FileChannel;
-import java.nio.file.Files;
-import java.nio.file.Path;
-import java.nio.file.StandardOpenOption;
-import java.util.Map;
-import java.util.Objects;
-
-/**
- * Support for AOT preloading of GGUF metadata with GraalVM's Native Image.
- *
- * <p>
- * To preload a model at build time, pass {@code -Dllama.PreloadGGUF=/path/to/model.gguf}
- * to the native-image builder command. At runtime, the preloaded model will be used
- * iff the specified and preloaded file names (base name) match.
- */
-public final class AOT {
-    AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
-
-    static LlamaModelLoader modelLoader;
-
-    record PartialModel(String modelFileName, Llama model, long tensorDataOffset, Map<String, GGUF.GGUFTensorInfo> tensorInfos) {
-    }
-
-    private static final PartialModel PRELOADED_GGUF = preLoadGGUF(System.getProperty("llama.PreloadGGUF"));
-
-    private static PartialModel preLoadGGUF(String modelPath) {
-        if (modelPath == null || modelPath.isEmpty()) {
-            return null;
-        }
-        try {
-            Path path = Path.of(modelPath);
-            if (!Files.exists(path) || !Files.isRegularFile(path)) {
-                throw new IllegalArgumentException("Cannot pre-load model: " + path);
-            }
-            GGUF gguf = GGUF.loadModel(path);
-            try (FileChannel fileChannel = FileChannel.open(path, StandardOpenOption.READ)) {
-                modelLoader = new LlamaModelLoader(fileChannel, gguf, Options.DEFAULT_MAX_TOKENS, false, false);
-                return new PartialModel(path.getFileName().toString(), modelLoader.loadModel(), // TODO: needs proper handling for AOT
-                        gguf.getTensorDataOffset(), gguf.getTensorInfos());
-            }
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
-    }
-
-    /**
-     * Tries to reuse a compatible AOT preloaded model.
-     * The file name (base name) must match with the preloaded file name.
-     * No checksum/hash is checked for performance reasons.
-     */
-    public static Model tryUsePreLoaded(Path modelPath, int contextLength) throws IOException {
-        AOT.PartialModel preLoaded = AOT.PRELOADED_GGUF;
-        if (preLoaded == null) {
-            return null; // no pre-loaded model stored
-        }
-        String optionsModel = modelPath.getFileName().toString();
-        String preLoadedModel = preLoaded.modelFileName();
-        if (!Objects.equals(optionsModel, preLoadedModel)) {
-            // Preloaded and specified model file names didn't match.
-            return null;
-        }
-        Llama baseModel = preLoaded.model();
-        try (var timer = Timer.log("Load tensors from pre-loaded model"); var fileChannel = FileChannel.open(modelPath, StandardOpenOption.READ)) {
-            // Load only the tensors (mmap slices).
-            Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, preLoaded.tensorDataOffset(), preLoaded.tensorInfos());
-            Weights weights = modelLoader.loadWeights(tensorEntries, baseModel.configuration());
-            return new Llama(baseModel.configuration().withContextLength(contextLength), baseModel.tokenizer(), weights, new LlamaChatFormat((LlamaTokenizer) baseModel.tokenizer()));
-        }
-    }
-}
-
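With AOT preloading removed, every caller goes through the single ModelLoader entry point. A minimal caller-side sketch using only the signatures kept by this patch (error handling omitted; 4096 is an arbitrary example context length):

    // Option-driven entry point: reads modelPath, maxTokens and useTornadovm from Options.
    Model model = ModelLoader.loadModel(options);
    // Explicit overload, as used internally by the Options variant:
    Model other = ModelLoader.loadModel(Path.of("model.gguf"), 4096, /* loadWeights */ true, /* useTornadovm */ false);
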
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/FP16Weights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/FP16Weights.java
similarity index 95%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/FP16Weights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/FP16Weights.java
index 90f419bd..c9ad8419 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/FP16Weights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/FP16Weights.java
@@ -1,6 +1,7 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.fp16;
 
 import org.beehive.gpullama3.core.model.GGMLType;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/LlamaTornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/LlamaTornadoWeights.java
similarity index 96%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/LlamaTornadoWeights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/LlamaTornadoWeights.java
index 00f601b8..02550e00 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/LlamaTornadoWeights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/LlamaTornadoWeights.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.fp16;
 
 import org.beehive.gpullama3.core.model.GGMLType;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Phi3TornadoWeights.java
similarity index 97%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Phi3TornadoWeights.java
index 92410bf1..e6c12254 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Phi3TornadoWeights.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.fp16;
 
 import org.beehive.gpullama3.core.model.GGMLType;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Qwen2TornadoWeights.java
similarity index 96%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Qwen2TornadoWeights.java
index 84617626..26c4d902 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Qwen2TornadoWeights.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.fp16;
 
 import org.beehive.gpullama3.core.model.GGMLType;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Qwen3TornadoWeights.java
similarity index 97%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3TornadoWeights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Qwen3TornadoWeights.java
index 1236c121..06869323 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3TornadoWeights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/fp16/Qwen3TornadoWeights.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.fp16;
 
 import org.beehive.gpullama3.core.model.GGMLType;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeightsQ8_0.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Phi3TornadoWeightsQ8_0.java
similarity index 97%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeightsQ8_0.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Phi3TornadoWeightsQ8_0.java
index fbccd336..2a901acd 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Phi3TornadoWeightsQ8_0.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Phi3TornadoWeightsQ8_0.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.q8_0;
 
 import org.beehive.gpullama3.core.model.GGMLType;
 import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Q8_0Weights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Q8_0Weights.java
similarity index 95%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Q8_0Weights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Q8_0Weights.java
index 04d4e11f..1de11ec4 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Q8_0Weights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Q8_0Weights.java
@@ -1,7 +1,8 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.q8_0;
 
 import org.beehive.gpullama3.core.model.GGMLType;
 import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
+import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
 public class Q8_0Weights implements TornadoWeights {
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeightsQ8_0.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Qwen2TornadoWeightsQ8_0.java
similarity index 96%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeightsQ8_0.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Qwen2TornadoWeightsQ8_0.java
index 6cc29905..fb50b926 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2TornadoWeightsQ8_0.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Qwen2TornadoWeightsQ8_0.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.q8_0;
 
 import org.beehive.gpullama3.core.model.GGMLType;
 import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3Q8_0TornadoWeights.java b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Qwen3Q8_0TornadoWeights.java
similarity index 96%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3Q8_0TornadoWeights.java
rename to src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Qwen3Q8_0TornadoWeights.java
index c5dce240..aa6f0fe5 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen3Q8_0TornadoWeights.java
+++ b/src/main/java/org/beehive/gpullama3/inference/weights/tornado/q8_0/Qwen3Q8_0TornadoWeights.java
@@ -1,4 +1,4 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.inference.weights.tornado.q8_0;
 
 import org.beehive.gpullama3.core.model.GGMLType;
 import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java
new file mode 100644
index 00000000..3c3e4ea3
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/model/loader/AbstractModelLoader.java
@@ -0,0 +1,172 @@
+package org.beehive.gpullama3.model.loader;
+
+import org.beehive.gpullama3.core.model.GGUF;
+import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
+import org.beehive.gpullama3.core.types.Pair;
+import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.model.Model;
+import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
+import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
+
+import java.io.IOException;
+import java.nio.channels.FileChannel;
+import java.util.Map;
+
+/**
+ * Abstract base class for model loaders using the Template Method pattern. Provides a common loading flow with extension points for model-specific logic.
+ *
+ * @param <M>
+ *         The specific Model type to load
+ * @param <C>
+ *         The specific Configuration type for the model
+ */
+public abstract class AbstractModelLoader<M extends Model, C extends Configuration> {
+
+    protected final FileChannel fileChannel;
+    protected final GGUF gguf;
+    protected final int contextLength;
+    protected final boolean loadWeights;
+    protected final boolean useTornadovm;
+
+    protected Vocabulary vocabulary;
+
+    protected AbstractModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
+        this.fileChannel = fileChannel;
+        this.gguf = gguf;
+        this.contextLength = contextLength;
+        this.loadWeights = loadWeights;
+        this.useTornadovm = useTornadovm;
+    }
+
+    /**
+     * Template method that defines the model loading workflow. Subclasses should not override this method.
+     *
+     * @return The loaded model instance
+     */
+    public final M loadModel() {
+        try {
+            Map<String, Object> metadata = gguf.getMetadata();
+
+            // Step 1: Load vocabulary
+            this.vocabulary = loadVocabulary(metadata);
+
+            // Step 2: Create tokenizer
+            Tokenizer tokenizer = createTokenizer(metadata, vocabulary);
+
+            // Step 3: Create configuration
+            C config = createConfiguration(metadata);
+
+            // Step 4: Load weights (if requested)
+            Weights weights = null;
+            if (loadWeights) {
+                Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
+                weights = loadWeights(tensorEntries, config);
+            }
+
+            // Step 5: Create and return model instance
+            return createModel(config, tokenizer, weights);
+
+        } catch (IOException e) {
+            throw new ModelLoadException("Failed to load model", e);
+        }
+    }
+
+    /**
+     * Load the vocabulary from GGUF metadata. Model-specific implementations must provide this method.
+     *
+     * @param metadata
+     *         The GGUF metadata map
+     * @return The loaded Vocabulary
+     */
+    protected abstract Vocabulary loadVocabulary(Map<String, Object> metadata);
+
+    /**
+     * Create a tokenizer instance for this model.
+     *
+     * @param metadata
+     *         The GGUF metadata map
+     * @param vocabulary
+     *         The loaded vocabulary
+     * @return The tokenizer instance
+     */
+    protected abstract Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary);
+
+    /**
+     * Create a configuration instance from GGUF metadata.
+     *
+     * @param metadata
+     *         The GGUF metadata map
+     * @return The configuration instance
+     */
+    protected abstract C createConfiguration(Map<String, Object> metadata);
+
+    /**
+     * Load model weights from tensor entries. Default implementation handles common weight loading logic.
+     *
+     * @param tensorEntries
+     *         Map of tensor names to tensor entries
+     * @param config
+     *         The model configuration
+     * @return The loaded weights
+     */
+    public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, C config) {
+        // Precompute RoPE frequencies
+        Pair<float[], float[]> ropeFreqs = precomputeRopeFrequencies(config);
+
+        // Get token embeddings and output weights
+        GGMLTensorEntry tokenEmbeddings = getTokenEmbeddings(tensorEntries);
+        GGMLTensorEntry outputWeight = getOutputWeight(tensorEntries, tokenEmbeddings);
+
+        // Delegate to specific implementation
+        if (useTornadovm) {
+            return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+        } else {
+            return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+        }
+    }
+
+    /**
+     * Create the final model instance.
+     *
+     * @param config
+     *         The model configuration
+     * @param tokenizer
+     *         The tokenizer
+     * @param weights
+     *         The loaded weights
+     * @return The model instance
+     */
+    protected abstract M createModel(C config, Tokenizer tokenizer, Weights weights);
+
+    /**
+     * Precompute RoPE frequencies for this model. Model-specific implementations provide their own RoPE configurations.
+     */
+    protected abstract Pair<float[], float[]> precomputeRopeFrequencies(C config);
+
+    /**
+     * Get the token embeddings tensor entry. Default implementation can be overridden for different tensor naming.
+     */
+    protected GGMLTensorEntry getTokenEmbeddings(Map<String, GGMLTensorEntry> tensorEntries) {
+        return tensorEntries.get("token_embd.weight");
+    }
+
+    /**
+     * Get the output weight tensor entry. Default implementation falls back to token embeddings if output.weight is not found.
+     */
+    protected GGMLTensorEntry getOutputWeight(Map<String, GGMLTensorEntry> tensorEntries, GGMLTensorEntry tokenEmbeddings) {
+        return tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
+    }
+
+    /**
+     * Create standard (CPU) weights.
+     */
+    protected abstract Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight);
+
+    /**
+     * Create TornadoVM (GPU) weights.
+     */
+    protected abstract Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight);
+}
\ No newline at end of file
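AbstractModelLoader fixes the loading order once (vocabulary, then tokenizer, then configuration, then optional weights, then the model object); subclasses only implement the protected hooks. A sketch of driving a concrete loader directly, mirroring what ModelLoader.loadModel does internally (2048 is an arbitrary context length; error handling simplified):

    GGUF gguf = GGUF.loadModel(ggufPath);
    try (FileChannel fileChannel = FileChannel.open(ggufPath, StandardOpenOption.READ)) {
        LlamaModelLoader loader = new LlamaModelLoader(fileChannel, gguf, 2048, /* loadWeights */ true, /* useTornadovm */ false);
        Llama model = loader.loadModel(); // runs the five template steps in order
    }
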
+ (int) metadata.get("llama.attention.head_count_kv") + : (int) metadata.get("llama.attention.head_count"), + vocabSize, + (int) metadata.get("llama.context_length"), + (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), + (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)).withContextLength(contextLength); + } + + @Override + protected Pair precomputeRopeFrequencies(LlamaConfiguration config) { + return RoPE.precomputeFreqsCis(config.contextLength(), config.dim() / config.numberOfHeads(), config.ropeTheta(), false, 1.0f, 1.0f, 1.0f, config.contextLength() + ); + } + + @Override + protected Llama createModel(LlamaConfiguration config, Tokenizer tokenizer, Weights weights) { + return new Llama(config, tokenizer, weights, ChatFormat.create(tokenizer, null)); } - // @formatter:off @Override - public Llama loadModel() { - try { - Map metadata = gguf.getMetadata(); - - Vocabulary vocabulary = Vocabulary.loadLlamaVocabulary(metadata); - Tokenizer tokenizer = new LlamaTokenizer(metadata, vocabulary); - - LlamaConfiguration config = new LlamaConfiguration( - (int) metadata.get("llama.embedding_length"), - (int) metadata.get("llama.feed_forward_length"), - (int) metadata.get("llama.block_count"), - (int) metadata.get("llama.attention.head_count"), - - metadata.containsKey("llama.attention.head_count_kv") ? - (int) metadata.get("llama.attention.head_count_kv") : - (int) metadata.get("llama.attention.head_count"), - - vocabulary.size(), - (int) metadata.get("llama.context_length"), - (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f), - (float) metadata.getOrDefault("llama.rope.freq_base", 10000f) - ).withContextLength(contextLength); - - Weights weights = null; - if (loadWeights) { - Map tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos()); - weights = loadWeights(tensorEntries, config); - } - return new Llama(config, tokenizer, weights, ChatFormat.create(tokenizer, null)); - } catch (IOException e) { - throw new RuntimeException(e); + protected Weights createStandardWeights(Map tensorEntries, LlamaConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + + return new LlamaStandardWeights( + loadQuantized(tokenEmbeddings), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." 
+ i + ".ffn_up.weight")), + loadQuantized(tensorEntries.get("output_norm.weight")), + new ArrayFloatTensor(ropeFreqs.first()), + new ArrayFloatTensor(ropeFreqs.second()), + loadQuantized(outputWeight), + outputWeight.ggmlType()); + } + + @Override + protected Weights createTornadoVMWeights(Map tensorEntries, LlamaConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) { + System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")"); } + + GGMLType ggmlType = outputWeight.ggmlType(); + return switch(ggmlType) { + case F16 -> createTornadoVMWeightsF16(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); + case Q8_0 -> createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight); + default -> throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights."); + }; + } + + private Weights createTornadoVMWeightsF16(Map tensorEntries, LlamaConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, + GGMLTensorEntry outputWeight) { + return new LlamaTornadoWeights( + loadTensorAsFloatArray(tokenEmbeddings), + loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), + floatBufferToFloatArray(tensorEntries.get("output_norm.weight")), + FloatArray.fromArray(ropeFreqs.first()), + FloatArray.fromArray(ropeFreqs.second()), + loadTensorAsHalfFloatArray(outputWeight), + outputWeight.ggmlType() + ); + } + + private Q8_0Weights createTornadoVMWeightsQ8_0(Map tensorEntries, LlamaConfiguration config, Pair ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) { + return new Q8_0Weights( + loadTensorAsFloatArray(tokenEmbeddings), + loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), + loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), + loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), + loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), + loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), + loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." 
+ i + ".ffn_norm.weight")), + loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), + loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), + loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), + floatBufferToFloatArray(tensorEntries.get("output_norm.weight")), + FloatArray.fromArray(ropeFreqs.first()), + FloatArray.fromArray(ropeFreqs.second()), + loadQ8_0QuantizedTensor(outputWeight), + outputWeight.ggmlType() + ); } - // @formatter:on } diff --git a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java index efe64234..189db23f 100644 --- a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java +++ b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java @@ -1,66 +1,152 @@ package org.beehive.gpullama3.model.loader; -import org.beehive.gpullama3.auxiliary.Timer; +import org.beehive.gpullama3.core.model.GGMLType; import org.beehive.gpullama3.core.model.GGUF; +import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor; import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry; +import org.beehive.gpullama3.core.types.Pair; +import org.beehive.gpullama3.inference.operation.RoPE; import org.beehive.gpullama3.inference.weights.Weights; +import org.beehive.gpullama3.inference.weights.standard.LlamaStandardWeights; +import org.beehive.gpullama3.inference.weights.tornado.fp16.LlamaTornadoWeights; +import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights; import org.beehive.gpullama3.model.format.ChatFormat; +import org.beehive.gpullama3.model.llama.LlamaConfiguration; import org.beehive.gpullama3.model.mistral.Mistral; import org.beehive.gpullama3.model.mistral.MistralConfiguration; import org.beehive.gpullama3.tokenizer.impl.MistralTokenizer; import org.beehive.gpullama3.tokenizer.impl.Tokenizer; import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary; +import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan; +import uk.ac.manchester.tornado.api.types.arrays.FloatArray; -import java.io.IOException; import java.nio.channels.FileChannel; import java.util.Map; -public class MistralModelLoader extends ModelLoader { +import static org.beehive.gpullama3.model.loader.ModelLoader.*; +import static org.beehive.gpullama3.model.loader.ModelLoader.floatBufferToFloatArray; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayAsFloatArrayFromBuffer; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayAsQ8_0QuantizedTensor; +import static org.beehive.gpullama3.model.loader.ModelLoader.loadQ8_0QuantizedTensor; + +public class MistralModelLoader extends AbstractModelLoader { public MistralModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) { super(fileChannel, gguf, contextLength, loadWeights, useTornadovm); } - // @formatter:off @Override - public Mistral loadModel() { - try { - Map metadata = gguf.getMetadata(); - - Vocabulary vocabulary = Vocabulary.loadMistralVocabulary(metadata); - Tokenizer tokenizer = new MistralTokenizer(metadata, vocabulary); - - int modelContextLength = (int) metadata.get("llama.context_length"); - if (contextLength < 0 || modelContextLength < contextLength) { - contextLength = modelContextLength; - } - - MistralConfiguration config = new 
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java
index efe64234..189db23f 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/MistralModelLoader.java
@@ -1,66 +1,152 @@
 package org.beehive.gpullama3.model.loader;
 
-import org.beehive.gpullama3.auxiliary.Timer;
+import org.beehive.gpullama3.core.model.GGMLType;
 import org.beehive.gpullama3.core.model.GGUF;
+import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
 import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
+import org.beehive.gpullama3.core.types.Pair;
+import org.beehive.gpullama3.inference.operation.RoPE;
 import org.beehive.gpullama3.inference.weights.Weights;
+import org.beehive.gpullama3.inference.weights.standard.LlamaStandardWeights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.LlamaTornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights;
 import org.beehive.gpullama3.model.format.ChatFormat;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
 import org.beehive.gpullama3.model.mistral.Mistral;
 import org.beehive.gpullama3.model.mistral.MistralConfiguration;
 import org.beehive.gpullama3.tokenizer.impl.MistralTokenizer;
 import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
 import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
+import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
+import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
 
-import java.io.IOException;
 import java.nio.channels.FileChannel;
 import java.util.Map;
 
-public class MistralModelLoader extends ModelLoader {
+import static org.beehive.gpullama3.model.loader.ModelLoader.*;
+import static org.beehive.gpullama3.model.loader.ModelLoader.floatBufferToFloatArray;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayAsFloatArrayFromBuffer;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayAsQ8_0QuantizedTensor;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadQ8_0QuantizedTensor;
+
+public class MistralModelLoader extends AbstractModelLoader<Mistral, MistralConfiguration> {
 
     public MistralModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
         super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
     }
 
-    // @formatter:off
     @Override
-    public Mistral loadModel() {
-        try {
-            Map<String, Object> metadata = gguf.getMetadata();
-
-            Vocabulary vocabulary = Vocabulary.loadMistralVocabulary(metadata);
-            Tokenizer tokenizer = new MistralTokenizer(metadata, vocabulary);
-
-            int modelContextLength = (int) metadata.get("llama.context_length");
-            if (contextLength < 0 || modelContextLength < contextLength) {
-                contextLength = modelContextLength;
-            }
-
-            MistralConfiguration config = new MistralConfiguration(
-                    (int) metadata.get("llama.embedding_length"),
-                    (int) metadata.get("llama.feed_forward_length"),
-                    (int) metadata.get("llama.block_count"),
-                    (int) metadata.get("llama.attention.head_count"),
-
-                    metadata.containsKey("llama.attention.head_count_kv")
-                            ? (int) metadata.get("llama.attention.head_count_kv")
-                            : (int) metadata.get("llama.attention.head_count"),
-
-                    vocabulary.size(),
-                    contextLength,
-                    false,
-                    (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
-                    (float) metadata.getOrDefault("llama.rope.freq_base", 10000f)
-            );
-
-            Weights weights = null;
-            if (loadWeights) {
-                Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
-                weights = loadWeights(tensorEntries, config);
-            }
-            return new Mistral(config, tokenizer, weights, ChatFormat.create(tokenizer, null));
-        } catch (IOException e) {
-            throw new RuntimeException(e);
+    protected Vocabulary loadVocabulary(Map<String, Object> metadata) {
+        return Vocabulary.loadMistralVocabulary(metadata);
+    }
+
+    @Override
+    protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+        return new MistralTokenizer(metadata, vocabulary);
+    }
+
+    @Override
+    protected MistralConfiguration createConfiguration(Map<String, Object> metadata) {
+        int modelContextLength = (int) metadata.get("llama.context_length");
+        int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength;
+
+        // Get vocabulary size from metadata
+        int vocabSize = metadata.containsKey("llama.vocab_size") ? (int) metadata.get("llama.vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length");
+
+        return new MistralConfiguration((int) metadata.get("llama.embedding_length"), (int) metadata.get("llama.feed_forward_length"), (int) metadata.get("llama.block_count"),
+                (int) metadata.get("llama.attention.head_count"),
+
+                metadata.containsKey("llama.attention.head_count_kv") ? (int) metadata.get("llama.attention.head_count_kv") : (int) metadata.get("llama.attention.head_count"),
+
+                vocabSize, finalContextLength, false, (float) metadata.getOrDefault("llama.attention.layer_norm_rms_epsilon", 1e-5f),
+                (float) metadata.getOrDefault("llama.rope.freq_base", 10000f));
+    }
+
+    @Override
+    protected Pair<float[], float[]> precomputeRopeFrequencies(MistralConfiguration config) {
+        return RoPE.precomputeFreqsCis(config.contextLength(), config.dim() / config.numberOfHeads(), config.ropeTheta(), false, 1.0f, 1.0f, 1.0f, config.contextLength());
+    }
+
+    @Override
+    protected Mistral createModel(MistralConfiguration config, Tokenizer tokenizer, Weights weights) {
+        return new Mistral(config, tokenizer, weights, ChatFormat.create(tokenizer, null));
+    }
+
+    @Override
+    protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, MistralConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
+
+        return new LlamaStandardWeights(
+                loadQuantized(tokenEmbeddings),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+                loadQuantized(tensorEntries.get("output_norm.weight")),
+                new ArrayFloatTensor(ropeFreqs.first()),
+                new ArrayFloatTensor(ropeFreqs.second()),
+                loadQuantized(outputWeight),
+                outputWeight.ggmlType());
+    }
+
+    @Override
+    protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, MistralConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
+        if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+            System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
         }
+
+        GGMLType ggmlType = outputWeight.ggmlType();
+        return switch (ggmlType) {
+            case F16 -> createTornadoVMWeightsF16(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+            case Q8_0 -> createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+            default -> throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
+        };
+    }
+
+    private Weights createTornadoVMWeightsF16(Map<String, GGMLTensorEntry> tensorEntries, MistralConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
+
+        return new LlamaTornadoWeights(ModelLoader.loadTensorAsFloatArray(tokenEmbeddings),
+                ModelLoader.loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+                ModelLoader.loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+                ModelLoader.loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+                ModelLoader.floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
+                FloatArray.fromArray(ropeFreqs.first()),
+                FloatArray.fromArray(ropeFreqs.second()),
+                ModelLoader.loadTensorAsHalfFloatArray(outputWeight),
+                outputWeight.ggmlType());
+    }
+
+    private Q8_0Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries, MistralConfiguration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
+        return new Q8_0Weights(
+                loadTensorAsFloatArray(tokenEmbeddings),
+                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
+                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
+                floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
+                FloatArray.fromArray(ropeFreqs.first()),
+                FloatArray.fromArray(ropeFreqs.second()),
+                loadQ8_0QuantizedTensor(outputWeight),
+                outputWeight.ggmlType()
+        );
     }
-    // @formatter:on
 }
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoadException.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoadException.java
new file mode 100644
index 00000000..f09ec56b
--- /dev/null
+++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoadException.java
@@ -0,0 +1,15 @@
+package org.beehive.gpullama3.model.loader;
+
+/**
+ * Exception thrown when model loading fails.
+ */
+public class ModelLoadException extends RuntimeException {
+
+    public ModelLoadException(String message) {
+        super(message);
+    }
+
+    public ModelLoadException(String message, Throwable cause) {
+        super(message, cause);
+    }
+}
\ No newline at end of file
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java
index 7d0b8dff..6b3c88eb 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/ModelLoader.java
@@ -1,32 +1,13 @@
 package org.beehive.gpullama3.model.loader;
 
 import org.beehive.gpullama3.Options;
-import org.beehive.gpullama3.aot.AOT;
 import org.beehive.gpullama3.core.model.GGMLType;
 import org.beehive.gpullama3.core.model.GGUF;
-import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
-import org.beehive.gpullama3.core.model.tensor.F16FloatTensor;
-import org.beehive.gpullama3.core.model.tensor.F32FloatTensor;
-import org.beehive.gpullama3.core.model.tensor.FloatTensor;
-import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
-import org.beehive.gpullama3.core.model.tensor.Q4_0FloatTensor;
-import org.beehive.gpullama3.core.model.tensor.Q8_0FloatTensor;
-import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
-import org.beehive.gpullama3.core.types.Pair;
-import org.beehive.gpullama3.inference.operation.RoPE;
-import org.beehive.gpullama3.inference.weights.Weights;
-import org.beehive.gpullama3.inference.weights.standard.LlamaStandardWeights;
-import org.beehive.gpullama3.inference.weights.tornado.LlamaTornadoWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Q8_0Weights;
-import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.core.model.tensor.*;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.model.ModelType;
-import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
 import uk.ac.manchester.tornado.api.types.HalfFloat;
-import uk.ac.manchester.tornado.api.types.arrays.ByteArray;
-import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
-import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
-import uk.ac.manchester.tornado.api.types.arrays.Int8Array;
+import uk.ac.manchester.tornado.api.types.arrays.*;
 
 import java.io.IOException;
 import java.lang.foreign.MemorySegment;
@@ -39,9 +20,10 @@ import java.util.Map;
 import java.util.function.IntFunction;
 
-public abstract class ModelLoader {
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadArrayAsHalfFloatArray;
+import static org.beehive.gpullama3.model.loader.ModelLoader.loadTensorAsFloatArray;
 
-    public static final boolean USE_AOT = Boolean.parseBoolean(System.getProperty("llama.AOT", "false")); // Use Ahead-of-Time compilation
+public abstract class ModelLoader {
 
     protected FileChannel fileChannel;
     protected GGUF gguf;
@@ -99,13 +81,6 @@ private static ModelType detectModelType(Map<String, Object> metadata) {
      *             if AOT loading is enabled but the preloaded model is unavailable
      */
     public static Model loadModel(Options options) throws IOException {
-        if (USE_AOT) {
-            Model model = AOT.tryUsePreLoaded(options.modelPath(), options.maxTokens());
-            if (model == null) {
-                throw new IllegalStateException("Failed to load precompiled AOT model.");
-            }
-            return model;
-        }
         return ModelLoader.loadModel(options.modelPath(), options.maxTokens(), true, options.useTornadovm());
     }
 
@@ -119,6 +94,10 @@ public static Model loadModel(Path ggufPath, int contextLength, boolean loadWeig
         return modelType.loadModel(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
     }
 
+    /**
+     * Dispatcher method for loading a standard (non-tornado) tensor based on type.
+     * Used in the CPU path.
+     */
     public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
         GGMLType ggmlType = entry.ggmlType();
         return switch (ggmlType) {
@@ -130,6 +109,55 @@ public static FloatTensor loadQuantized(GGMLTensorEntry entry) {
         };
     }
 
+    /**
+     * Dispatcher method for loading a standard tensor array based on type.
+     * Used in the CPU path.
+     */
+    public static FloatTensor[] loadArrayOfQuantized(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+        FloatTensor[] array = new FloatTensor[size];
+        for (int i = 0; i < size; i++) {
+            array[i] = loadQuantized(getTensorEntry.apply(i));
+        }
+        return array;
+    }
+
+    /**
+     * [WIP]
+     * Dispatcher method for loading a TornadoVM tensor based on type.
+     * Used in the GPU path.
+     *
+     * TODO: fix this to follow loadQuantized logic
+     */
+    public static FloatTensor loadTornadoTensor(GGMLTensorEntry entry) {
+        GGMLType ggmlType = entry.ggmlType();
+        int size = FloatTensor.numberOfElements(entry.shape());
+        return switch (ggmlType) {
+//            case F32 -> new F32QuantizedTensor(size, entry.memorySegment());
+            case Q8_0 -> loadQ8_0QuantizedTensor(entry);
+//            case Q4_0 -> throw new UnsupportedOperationException("Not yet implemented");
+//            //FloatTensor.numberOfElements(entry.shape()), entry.memorySegment()
+//            case F16 -> new F16QuantizedTensor(size, entry.memorySegment());
+//            /*{
+//                HalfFloatArray array = new HalfFloatArray();
+//                array.getSegment().copyFrom(entry.memorySegment());
+//                // or array.getSegmentWithHeader() ?
+//            }*/
+            default -> throw new UnsupportedOperationException("Quantization format " + ggmlType);
+        };
+    }
+
+    /**
+     * Dispatcher method for loading a TornadoVM tensor array based on type.
+     * Used in the GPU path.
+     */
+    public static FloatTensor[] loadTornadoTensorArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
+        FloatTensor[] array = new FloatTensor[size];
+        for (int i = 0; i < size; i++) {
+            array[i] = loadTornadoTensor(getTensorEntry.apply(i));
+        }
+        return array;
+    }
+
     public static FloatArray[] loadArrayAsFloatArray(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
         FloatArray[] array = new FloatArray[size];
         for (int i = 0; i < size; i++) {
@@ -216,6 +244,8 @@ public static HalfFloatArray loadTensorAsHalfFloatArray(GGMLTensorEntry entry) {
         }
     }
 
+    // TODO: rename to loadQ8_0Tensor
+    // move to a utils class
     public static Q8_0QuantizedTensor loadQ8_0QuantizedTensor(GGMLTensorEntry entry) {
         if (entry.ggmlType() != GGMLType.Q8_0) {
             throw new IllegalArgumentException("Expected Q8_0 tensor, got: " + entry.ggmlType() + " for tensor: " + entry.name());
@@ -240,6 +270,7 @@
         ValueLayout.OfByte byteLayout = ValueLayout.JAVA_BYTE;
 
         for (int block = 0; block < numBlocks; block++) {
+            // TODO: use GGML type method for the 34L size
            long blockOffset = block * 34L; // 34 bytes per block
 
             // read fp16 scale (first 2 bytes of block)
@@ -247,6 +278,7 @@
             scales.set(block, new HalfFloat(scaleRaw));
 
             // read 32 int8 quantized values (remaining bytes of block)
+            // TODO: use GGML type method for the 32 size
             for (int i = 0; i < 32; i++) {
                 byte quantValue = q8Segment.get(byteLayout, blockOffset + 2 + i);
                 quants.set(block * 32 + i, quantValue);
@@ -256,14 +288,6 @@
         return new Q8_0QuantizedTensor(size, scales, quants, q8Segment);
     }
 
-    public static FloatTensor[] loadArrayOfQuantized(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
-        FloatTensor[] array = new FloatTensor[size];
-        for (int i = 0; i < size; i++) {
-            array[i] = loadQuantized(getTensorEntry.apply(i));
-        }
-        return array;
-    }
-
     public static FloatBuffer[] loadArrayOfFloatBuffer(int size, IntFunction<GGMLTensorEntry> getTensorEntry) {
         FloatBuffer[] array = new FloatBuffer[size];
         for (int i = 0; i < size; i++) {
@@ -282,104 +306,6 @@ public static FloatBuffer toFloatBuffer(GGMLTensorEntry tensorEntry) {
 
     public abstract Model loadModel();
 
-    //@formatter:off
-    public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
-        boolean ropeScaling = tensorEntries.containsKey("rope_freqs");
-        RopeConfig ropeConfig = new RopeConfig(8.0f, // scaleFactor
-                1.0f, // loFreqFactor
-                3.0f, // hiFreqFactor
-                8192 // oldContextLength
-        );
-
-        Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(
-                config.contextLength(), // Maximum sequence length the model can process
-                config.headSize(), // Dimension of each attention head
-                config.ropeTheta(), // Base frequency parameter (typically 10000.0)
-                ropeScaling, // Whether to apply frequency scaling (determined by model type)
-                ropeConfig.scaleFactor, // Scale factor for extending context length (NTK-aware scaling)
-                ropeConfig.loFreqFactor, // Low frequency scaling factor for better long-range dependencies
-                ropeConfig.hiFreqFactor, // High frequency scaling factor for preserving local precision
-                ropeConfig.oldContextLength // Original context length the model was trained with
-        );
-
-        GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
-        GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
-
-        if (useTornadovm) {
-            if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
-                System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + ")");
-            }
-
-            if (outputWeight.ggmlType() == GGMLType.Q8_0) {
-                return createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
-            } else {
-                return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
-            }
-        } else {
-            return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
-        }
-    }
-
-    public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
-            GGMLTensorEntry outputWeight) {
-        return new LlamaTornadoWeights(
-                // Load directly to TornadoVM format
-                loadTensorAsFloatArray(tokenEmbeddings), loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
-                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
-                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
-                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
-                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
-                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
-                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
-                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
-                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
-                FloatArray.fromArray(ropeFreqs.first()), FloatArray.fromArray(ropeFreqs.second()), loadTensorAsHalfFloatArray(outputWeight), outputWeight.ggmlType()) {
-        };
-    }
-
-    private Q8_0Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings, GGMLTensorEntry outputWeight) {
-        return new Q8_0Weights(
-                loadTensorAsFloatArray(tokenEmbeddings),
-                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
-                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
-                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
-                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
-                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
-                loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
-                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
-                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
-                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
-                floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
-                FloatArray.fromArray(ropeFreqs.first()),
-                FloatArray.fromArray(ropeFreqs.second()),
-                loadQ8_0QuantizedTensor(outputWeight),
-                outputWeight.ggmlType()
-        );
-    }
-
-    /**
-     * Creates weights in standard format only
-     */
-    public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
-            GGMLTensorEntry outputWeight) {
-        return new LlamaStandardWeights(
-                loadQuantized(tokenEmbeddings),
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")),
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")),
-                loadQuantized(tensorEntries.get("output_norm.weight")),
-                new ArrayFloatTensor(ropeFreqs.first()),
-                new ArrayFloatTensor(ropeFreqs.second()),
-                loadQuantized(outputWeight),
-                outputWeight.ggmlType());
-    }
-
     // Helper class to encapsulate RoPE configuration parameters
     private static class RopeConfig {
         final float scaleFactor;
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java
index 14b0dab7..9f118dfa 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/Phi3ModelLoader.java
@@ -1,18 +1,17 @@
 package org.beehive.gpullama3.model.loader;
 
-import org.beehive.gpullama3.LlamaApp;
-import org.beehive.gpullama3.Options;
-import org.beehive.gpullama3.auxiliary.Timer;
 import org.beehive.gpullama3.core.model.GGMLType;
 import org.beehive.gpullama3.core.model.GGUF;
 import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
+import org.beehive.gpullama3.core.model.tensor.FloatTensor;
 import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
 import org.beehive.gpullama3.core.types.Pair;
 import org.beehive.gpullama3.inference.operation.RoPE;
 import org.beehive.gpullama3.inference.weights.Weights;
 import org.beehive.gpullama3.inference.weights.standard.Phi3StandardWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeightsQ8_0;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.Phi3TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Phi3TornadoWeightsQ8_0;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights;
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.format.ChatFormat;
 import org.beehive.gpullama3.model.phi3.Phi3;
@@ -22,164 +21,185 @@ import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
 import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 
-import java.io.IOException;
 import java.nio.channels.FileChannel;
 import java.util.Map;
 
-public class Phi3ModelLoader extends ModelLoader {
+import static org.beehive.gpullama3.model.loader.ModelLoader.*;
+
+public class Phi3ModelLoader extends AbstractModelLoader<Phi3, Phi3Configuration> {
+
+    private int modelContextLength;
+
     public Phi3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
         super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
     }
 
-    // @formatter:off
     @Override
-    public Phi3 loadModel() {
-        try {
-            Map<String, Object> metadata = gguf.getMetadata();
-            final String modelPrefix = "phi3.";
-
-            Vocabulary vocabulary = Vocabulary.loadPhi3Vocabulary(metadata);
-            Tokenizer tokenizer = new Phi3Tokenizer(metadata, vocabulary);
-
-            if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
-                System.out.println("Tokenizer: " + tokenizer.getClass().getSimpleName());
-            }
-
-            int modelContextLength = (int) metadata.get(modelPrefix + "context_length");
-            if (contextLength < 0 || modelContextLength < contextLength) {
-                contextLength = modelContextLength;
-            }
-
-            Phi3Configuration config = new Phi3Configuration(
-                    (int) metadata.get(modelPrefix + "embedding_length"), // dim
-                    (int) metadata.get(modelPrefix + "feed_forward_length"), // hidden_dim
-                    (int) metadata.get(modelPrefix + "block_count"), // n_layers
-                    (int) metadata.get(modelPrefix + "attention.head_count"), // n_heads
-
-                    metadata.containsKey(modelPrefix + "attention.head_count_kv")
-                            ? (int) metadata.get(modelPrefix + "attention.head_count_kv")
-                            : (int) metadata.get(modelPrefix + "attention.head_count"), // n_kv_heads
-
-                    vocabulary.size(), // vocab_size
-                    contextLength, // context_length (user-specified, not model)
-                    (float) metadata.getOrDefault(modelPrefix + "attention.layer_norm_rms_epsilon", 1e-5f), // rms_norm_eps
-                    (float) metadata.getOrDefault(modelPrefix + "rope.freq_base", 10000f) // rope_theta
-            );
-
-            Weights weights = null;
-            if (loadWeights) {
-                Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
-                weights = loadWeights(tensorEntries, config, modelContextLength);
-            }
-
-            // Phi3 chat tokens
-            ChatFormat.ChatTokens chatTokens = new ChatFormat.ChatTokens(
-                    "<|system|>", "<|end|>", "<|user|>", "<|end|>", "<|assistant|>"
-            );
-
-            return new Phi3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
+    protected Vocabulary loadVocabulary(Map<String, Object> metadata) {
+        return Vocabulary.loadPhi3Vocabulary(metadata);
     }
-    // @formatter:on
 
-    // @formatter:off
-    private Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, int modelContextLength) {
+    @Override
+    protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+        Tokenizer tokenizer = new Phi3Tokenizer(metadata, vocabulary);
+        if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+            System.out.println("Tokenizer: " + tokenizer.getClass().getSimpleName());
+        }
+        return tokenizer;
+    }
+
+    @Override
+    protected Phi3Configuration createConfiguration(Map<String, Object> metadata) {
+        final String modelPrefix = "phi3.";
+        modelContextLength = (int) metadata.get(modelPrefix + "context_length");
+        int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength;
+
+        int vocabSize = metadata.containsKey(modelPrefix + "vocab_size") ? (int) metadata.get(modelPrefix + "vocab_size") : (int) metadata.get("tokenizer.ggml.tokens.length");
+
+        return new Phi3Configuration(
+                (int) metadata.get(modelPrefix + "embedding_length"), // dim
+                (int) metadata.get(modelPrefix + "feed_forward_length"), // hidden_dim
+                (int) metadata.get(modelPrefix + "block_count"), // n_layers
+                (int) metadata.get(modelPrefix + "attention.head_count"), // n_heads
+
+                metadata.containsKey(modelPrefix + "attention.head_count_kv")
+                        ? (int) metadata.get(modelPrefix + "attention.head_count_kv")
+                        : (int) metadata.get(modelPrefix + "attention.head_count"), // n_kv_heads
+
+                vocabSize, // vocab_size
+                finalContextLength, // context_length (user-specified, not model)
+                (float) metadata.getOrDefault(modelPrefix + "attention.layer_norm_rms_epsilon", 1e-5f), // rms_norm_eps
+                (float) metadata.getOrDefault(modelPrefix + "rope.freq_base", 10000f) // rope_theta
+        );
+    }
+
+    @Override
+    protected Pair<float[], float[]> precomputeRopeFrequencies(Phi3Configuration config) {
         // Calculate head size from dim and numberOfHeads
         int headSize = config.dim() / config.numberOfHeads();
-        Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(
-                modelContextLength, // Use model context length for RoPE precomputation
+        return RoPE.precomputeFreqsCis(modelContextLength, // Use model context length for RoPE precomputation
                 headSize, // Calculated head size
-                config.ropeTheta(),
-                false, // Phi3 uses standard RoPE, not neox-style based on reference
+                config.ropeTheta(), false, // Phi3 uses standard RoPE, not neox-style based on reference
                 8, 1, 3, 8192 // Additional RoPE parameters from reference
         );
+    }
 
-        GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
-        GGMLTensorEntry outputWeight = tensorEntries.get("output.weight"); // Phi3 always has separate output weight
-
-        if (useTornadovm) {
-            if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
-                System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + ")");
-            }
-            if (outputWeight.ggmlType() == GGMLType.Q8_0) {
-                return createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
-            } else {
-                return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
-            }
-        } else {
-            return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
-        }
+    @Override
+    protected Phi3 createModel(Phi3Configuration config, Tokenizer tokenizer, Weights weights) {
+        // Phi3 chat tokens
+        ChatFormat.ChatTokens chatTokens = new ChatFormat.ChatTokens("<|system|>", "<|end|>", "<|user|>", "<|end|>", "<|assistant|>");
+
+        return new Phi3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
+    }
+
+    @Override
+    protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
+        float[] ropeFreqsReal = ropeFreqs.first();
+        float[] ropeFreqsImag = ropeFreqs.second();
+
+        return new Phi3StandardWeights(
+                loadQuantized(tokenEmbeddings), // token_embedding_table
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight (as FloatTensor[])
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // wqkv (combined)
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // rms_ffn_weight (as FloatTensor[])
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (separate, not combined)
+                loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight (as FloatTensor)
+                new ArrayFloatTensor(ropeFreqsReal), // freq_cis_real
+                new ArrayFloatTensor(ropeFreqsImag), // freq_cis_imag
+                loadQuantized(outputWeight), // wcls
+                outputWeight.ggmlType() // weightType
+        );
     }
-    // @formatter:on
 
-    // @formatter:off
-    public Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries, Configuration config,
-            Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+    @Override
+    protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
             GGMLTensorEntry outputWeight) {
-        return new Phi3TornadoWeightsQ8_0(
+        if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+            System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
+        }
+
+        GGMLType ggmlType = outputWeight.ggmlType();
+        return switch (ggmlType) {
+            case F16 -> createTornadoVMWeightsF16(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+            case Q8_0 -> createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+            default -> throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
+        };
+    }
+
+    private Weights createTornadoVMWeightsF16(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
+        return new Phi3TornadoWeights(
                 loadTensorAsFloatArray(tokenEmbeddings),
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
-                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // Combined QKV
-                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
+                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // Combined QKV
+                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
-                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
-                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (not combined in reference)
+                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
+                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (not combined in reference)
                 floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
                 FloatArray.fromArray(ropeFreqs.first()),
                 FloatArray.fromArray(ropeFreqs.second()),
-                loadQ8_0QuantizedTensor(outputWeight),
+                loadTensorAsHalfFloatArray(outputWeight),
                 outputWeight.ggmlType()
         );
     }
 
-    public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config,
-            Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
-            GGMLTensorEntry outputWeight) {
-        return new Phi3TornadoWeights(
+    public Q8_0Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries, Configuration config,
+            Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
+        return new Phi3TornadoWeightsQ8_0(
                 loadTensorAsFloatArray(tokenEmbeddings),
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
-                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // Combined QKV
-                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // Combined QKV
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
-                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
-                loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (not combined in reference)
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
+                loadArrayAsQ8_0QuantizedTensor(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (not combined in reference)
                 floatBufferToFloatArray(tensorEntries.get("output_norm.weight")),
                 FloatArray.fromArray(ropeFreqs.first()),
                 FloatArray.fromArray(ropeFreqs.second()),
-                loadTensorAsHalfFloatArray(outputWeight),
+                loadQ8_0QuantizedTensor(outputWeight),
                 outputWeight.ggmlType()
         );
     }
-    // @formatter:on
 
-    // @formatter:off
-    @Override
-    public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries,
-            Configuration config,
-            Pair<float[], float[]> ropeFreqs,
-            GGMLTensorEntry tokenEmbeddings,
-            GGMLTensorEntry outputWeight) {
-        float[] ropeFreqsReal = ropeFreqs.first();
-        float[] ropeFreqsImag = ropeFreqs.second();
-
-        return new Phi3StandardWeights(
-                loadQuantized(tokenEmbeddings), // token_embedding_table
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight (as FloatTensor[])
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_qkv.weight")), // wqkv (combined)
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // rms_ffn_weight (as FloatTensor[])
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // wDown
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // wUp (separate, not combined)
-                loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight (as FloatTensor)
-                new ArrayFloatTensor(ropeFreqsReal), // freq_cis_real
-                new ArrayFloatTensor(ropeFreqsImag), // freq_cis_imag
-                loadQuantized(outputWeight), // wcls
-                outputWeight.ggmlType() // weightType
-        );
+    // Helper methods
+    private FloatTensor[] loadLayerWeights(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, String layerName, String suffix) {
+        FloatTensor[] weights = new FloatTensor[config.numberOfLayers()];
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+            weights[i] = ModelLoader.loadQuantized(tensorEntries.get(key));
+        }
+        return weights;
+    }
+
+    private FloatArray[] loadLayerWeightsAsFloatArraysFromBuffer(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, String layerName, String suffix) {
+        FloatArray[] weights = new FloatArray[config.numberOfLayers()];
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+            weights[i] = ModelLoader.floatBufferToFloatArray(tensorEntries.get(key));
+        }
+        return weights;
+    }
+
+    private HalfFloatArray[] loadLayerWeightsAsHalfFloatArrays(Map<String, GGMLTensorEntry> tensorEntries, Phi3Configuration config, String layerName, String suffix) {
+        HalfFloatArray[] weights = new HalfFloatArray[config.numberOfLayers()];
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+            weights[i] = ModelLoader.loadTensorAsHalfFloatArray(tensorEntries.get(key));
+        }
+        return weights;
     }
-    // @formatter:on
 }
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java
index fef3eb9d..bd31066e 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen2ModelLoader.java
@@ -1,18 +1,18 @@
 package org.beehive.gpullama3.model.loader;
 
-import org.beehive.gpullama3.Options;
 import org.beehive.gpullama3.core.model.GGMLType;
 import org.beehive.gpullama3.core.model.GGUF;
 import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
+import org.beehive.gpullama3.core.model.tensor.FloatTensor;
 import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
 import org.beehive.gpullama3.core.types.Pair;
 import org.beehive.gpullama3.inference.operation.RoPE;
 import org.beehive.gpullama3.inference.weights.Weights;
 import org.beehive.gpullama3.inference.weights.standard.Qwen2StandardWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeightsQ8_0;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.Qwen2TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Qwen2TornadoWeightsQ8_0;
 import org.beehive.gpullama3.model.Configuration;
-import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.model.format.ChatFormat;
 import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens;
 import org.beehive.gpullama3.model.qwen2.Qwen2;
@@ -22,97 +22,76 @@ import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
 import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 
-import java.io.IOException;
 import java.nio.channels.FileChannel;
 import java.util.Map;
 
+import static org.beehive.gpullama3.core.model.GGMLType.F16;
+import static org.beehive.gpullama3.model.loader.ModelLoader.*;
 import static org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary.loadQwen3Vocabulary;
 
-public class Qwen2ModelLoader extends ModelLoader {
+public class Qwen2ModelLoader extends AbstractModelLoader<Qwen2, Qwen2Configuration> {
 
     public Qwen2ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
         super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
     }
 
     @Override
-    public Model loadModel() {
-        Map<String, Object> metadata = gguf.getMetadata();
-        String basename = (String) metadata.get("general.basename");
-
-        String modelName = "DeepSeek-R1-Distill-Qwen".equals(basename) ? "DeepSeek-R1-Distill-Qwen" : "Qwen2.5";
-
-        try {
-            // reuse method of Qwen3
-            Vocabulary vocabulary = loadQwen3Vocabulary(metadata);
-            boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
-            Tokenizer tokenizer = new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
-
-            int modelContextLength = (int) metadata.get("qwen2.context_length");
-            if (contextLength < 0 || modelContextLength < contextLength) {
-                contextLength = modelContextLength;
-            }
-
-            int numberOfKeyValueHeads = metadata.containsKey("qwen2.attention.head_count_kv") ? (int) metadata.get("qwen2.attention.head_count_kv") : (int) metadata.get("qwen2.attention.head_count");
-            Qwen2Configuration config = new Qwen2Configuration((int) metadata.get("qwen2.embedding_length"), // dim
-                    (int) metadata.get("qwen2.feed_forward_length"), // hiddendim
-                    (int) metadata.get("qwen2.block_count"), // numberOfLayers
-                    (int) metadata.get("qwen2.attention.head_count"), // numberOfHeads
-
-                    numberOfKeyValueHeads, // numberOfKeyValueHeads
-                    numberOfKeyValueHeads, // numberOfHeadsKey
-                    numberOfKeyValueHeads, // numberOfHeadsValue
-
-                    vocabulary.size(), modelContextLength, contextLength, false, (float) metadata.get("qwen2.attention.layer_norm_rms_epsilon"), (float) metadata.get("qwen2.rope.freq_base"));
-
-            Weights weights = null;
-            if (loadWeights) {
-                Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
-                weights = loadWeights(tensorEntries, config);
-            }
-            // Qwen2.5-Coder uses <|endoftext|> as stop-token.
-            ChatTokens chatTokens = isDeepSeekR1DistillQwen
-                    ? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "")
-                    : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
-            return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
+    protected Vocabulary loadVocabulary(Map<String, Object> metadata) {
+        return loadQwen3Vocabulary(metadata);
+    }
+
+    @Override
+    protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+        boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
+        return new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
     }
 
-    // @formatter:off
     @Override
-    public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
-        Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(
-                config.contextLengthModel(),
-                config.headSize(),
-                config.ropeTheta(),
+    protected Qwen2Configuration createConfiguration(Map<String, Object> metadata) {
+        int modelContextLength = (int) metadata.get("qwen2.context_length");
+        int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength;
+
+        int numberOfKeyValueHeads = metadata.containsKey("qwen2.attention.head_count_kv") ? (int) metadata.get("qwen2.attention.head_count_kv") : (int) metadata.get("qwen2.attention.head_count");
+        int vocabSize = vocabulary.size();
+
+        return new Qwen2Configuration(
+                (int) metadata.get("qwen2.embedding_length"), // dim
+                (int) metadata.get("qwen2.feed_forward_length"), // hiddendim
+                (int) metadata.get("qwen2.block_count"), // numberOfLayers
+                (int) metadata.get("qwen2.attention.head_count"), // numberOfHeads
+
+                numberOfKeyValueHeads, // numberOfKeyValueHeads
+                numberOfKeyValueHeads, // numberOfHeadsKey
+                numberOfKeyValueHeads, // numberOfHeadsValue
+
+                vocabSize,
+                modelContextLength,
+                finalContextLength,
                 false,
-                8,
-                1,
-                3,
-                8192
+                (float) metadata.get("qwen2.attention.layer_norm_rms_epsilon"),
+                (float) metadata.get("qwen2.rope.freq_base")
         );
+    }
 
-        GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
-        GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
-
-        if (useTornadovm) {
-            if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
-                System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + ")");
-            }
-            if (outputWeight.ggmlType() == GGMLType.Q8_0) {
-                return createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
-            } else {
-                return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
-            }
-        } else {
-            return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
-        }
+    @Override
+    protected Pair<float[], float[]> precomputeRopeFrequencies(Qwen2Configuration config) {
+        return RoPE.precomputeFreqsCis(config.contextLengthModel(), config.headSize(), config.ropeTheta(), false, 8, 1, 3, 8192);
     }
 
     @Override
-    public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+    protected Qwen2 createModel(Qwen2Configuration config, Tokenizer tokenizer, Weights weights) {
+        Map<String, Object> metadata = gguf.getMetadata();
+        boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
+        // Qwen2.5-Coder uses <|endoftext|> as stop-token.
+        ChatTokens chatTokens = isDeepSeekR1DistillQwen
+                ? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "")
+                : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
+        return new Qwen2(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
+    }
+
+    @Override
+    protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
             GGMLTensorEntry outputWeight) {
         return new Qwen2StandardWeights(
                 loadQuantized(tokenEmbeddings),
@@ -120,11 +99,9 @@ public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries,
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")),
-
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
-
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
                 loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")),
@@ -134,12 +111,28 @@ public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries,
                 new ArrayFloatTensor(ropeFreqs.first()),
                 new ArrayFloatTensor(ropeFreqs.second()),
                 loadQuantized(outputWeight),
-                outputWeight.ggmlType());
+                outputWeight.ggmlType()
+        );
     }
 
     @Override
-    public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+    protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
             GGMLTensorEntry outputWeight) {
+        if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+            System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + F16 + ")");
+        }
+
+        GGMLType ggmlType = outputWeight.ggmlType();
+        return switch (ggmlType) {
+            case F16 -> createTornadoVMWeightsF16(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+            case Q8_0 -> createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+            default -> throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
+        };
+    }
+
+    private Weights createTornadoVMWeightsF16(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
         return new Qwen2TornadoWeights(
                 loadTensorAsFloatArray(tokenEmbeddings),
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
@@ -150,7 +143,6 @@ public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries,
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.bias")),
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.bias")),
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.bias")),
-
                 loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")),
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")),
                 loadArrayAsHalfFloatArray(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
@@ -164,8 +156,8 @@ public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries,
         );
     }
 
-    public Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
-            GGMLTensorEntry outputWeight) {
+    public Q8_0Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
         return new Qwen2TornadoWeightsQ8_0(
                 loadTensorAsFloatArray(tokenEmbeddings),
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
@@ -189,6 +181,32 @@ public Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries,
                 outputWeight.ggmlType()
         );
     }
-    // @formatter:on
 
+    // Helper methods
+    private FloatTensor[] loadLayerWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, String layerName, String suffix) {
+        FloatTensor[] weights = new FloatTensor[config.numberOfLayers()];
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+            weights[i] = ModelLoader.loadQuantized(tensorEntries.get(key));
+        }
+        return weights;
+    }
+
+    private FloatArray[] loadLayerWeightsAsFloatArraysFromBuffer(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, String layerName, String suffix) {
+        FloatArray[] weights = new FloatArray[config.numberOfLayers()];
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+            weights[i] = ModelLoader.floatBufferToFloatArray(tensorEntries.get(key));
+        }
+        return weights;
+    }
+
+    private HalfFloatArray[] loadLayerWeightsAsHalfFloatArrays(Map<String, GGMLTensorEntry> tensorEntries, Qwen2Configuration config, String layerName, String suffix) {
+        HalfFloatArray[] weights = new HalfFloatArray[config.numberOfLayers()];
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+            weights[i] = ModelLoader.loadTensorAsHalfFloatArray(tensorEntries.get(key));
+        }
+        return weights;
+    }
 }
diff --git a/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java b/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java
index 682c7477..597292cd 100644
--- a/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java
+++ b/src/main/java/org/beehive/gpullama3/model/loader/Qwen3ModelLoader.java
@@ -5,16 +5,18 @@
 import org.beehive.gpullama3.core.model.GGMLType;
 import org.beehive.gpullama3.core.model.GGUF;
 import org.beehive.gpullama3.core.model.tensor.ArrayFloatTensor;
+import org.beehive.gpullama3.core.model.tensor.FloatTensor;
 import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
 import org.beehive.gpullama3.core.types.Pair;
 import org.beehive.gpullama3.inference.operation.RoPE;
 import org.beehive.gpullama3.inference.weights.Weights;
 import org.beehive.gpullama3.inference.weights.standard.Qwen3StandardWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen3Q8_0TornadoWeights;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
-import org.beehive.gpullama3.model.Configuration;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Qwen3Q8_0TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.Qwen3TornadoWeights;
 import org.beehive.gpullama3.model.format.ChatFormat;
 import org.beehive.gpullama3.model.format.ChatFormat.ChatTokens;
+import org.beehive.gpullama3.model.llama.LlamaConfiguration;
 import org.beehive.gpullama3.model.qwen3.Qwen3;
 import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
 import org.beehive.gpullama3.tokenizer.impl.Qwen3Tokenizer;
@@ -22,105 +24,120 @@ import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;
 import org.beehive.gpullama3.tornadovm.TornadoVMMasterPlan;
 import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
+import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;
 
-import java.io.IOException;
 import java.nio.channels.FileChannel;
 import java.util.Map;
 
+import static org.beehive.gpullama3.model.loader.ModelLoader.*;
 import static org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary.loadQwen3Vocabulary;
 
-public class Qwen3ModelLoader extends ModelLoader {
+public class Qwen3ModelLoader extends AbstractModelLoader<Qwen3, Qwen3Configuration> {
 
     public Qwen3ModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
         super(fileChannel, gguf, contextLength, loadWeights, useTornadovm);
     }
 
-    // @formatter:off
     @Override
-    public Qwen3 loadModel() {
-        try {
-            Map<String, Object> metadata = gguf.getMetadata();
-
-            Vocabulary vocabulary = loadQwen3Vocabulary(metadata);
-            boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
-            Tokenizer tokenizer = new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
-
-            int modelContextLength = (int) metadata.get("qwen3.context_length");
-            if (contextLength < 0 || modelContextLength < contextLength) {
-                contextLength = modelContextLength;
-            }
-
-            Qwen3Configuration config = new Qwen3Configuration(
-                    (int) metadata.get("qwen3.embedding_length"),
-                    (int) metadata.get("qwen3.feed_forward_length"),
-                    (int) metadata.get("qwen3.block_count"),
-                    (int) metadata.get("qwen3.attention.head_count"),
-
-                    metadata.containsKey("qwen3.attention.head_count_kv")
-                            ? (int) metadata.get("qwen3.attention.head_count_kv")
-                            : (int) metadata.get("qwen3.attention.head_count"),
-                    (int) metadata.get("qwen3.attention.key_length"),
-                    (int) metadata.get("qwen3.attention.value_length"),
-
-                    vocabulary.size(),
-                    modelContextLength, contextLength,
-                    false,
-                    (float) metadata.get("qwen3.attention.layer_norm_rms_epsilon"),
-                    (float) metadata.get("qwen3.rope.freq_base")
-            );
-
-            Weights weights = null;
-            if (loadWeights) {
-                Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
-                weights = loadWeights(tensorEntries, config);
-            }
-            // Qwen2.5-coder uses <|endoftext|> as stop-token.
-            ChatTokens chatTokens = isDeepSeekR1DistillQwen ?
-                    new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "") :
-                    new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
-            return new Qwen3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
-        } catch (IOException e) {
-            throw new RuntimeException(e);
-        }
+    protected Vocabulary loadVocabulary(Map<String, Object> metadata) {
+        return loadQwen3Vocabulary(metadata);
     }
-    // @formatter:on
 
-    // @formatter:off
     @Override
-    public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config) {
-        Pair<float[], float[]> ropeFreqs = RoPE.precomputeFreqsCis(
-                config.contextLengthModel(),
-                config.numberOfHeadsKey(),
-                config.ropeTheta(),
+    protected Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary) {
+        boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
+        return new Qwen3Tokenizer(metadata, vocabulary, isDeepSeekR1DistillQwen);
+    }
+
+    @Override
+    protected Qwen3Configuration createConfiguration(Map<String, Object> metadata) {
+        int modelContextLength = (int) metadata.get("qwen3.context_length");
+        int finalContextLength = (contextLength < 0 || modelContextLength < contextLength) ? modelContextLength : contextLength;
+
+        int vocabSize = vocabulary.size();
+
+        return new Qwen3Configuration(
+                (int) metadata.get("qwen3.embedding_length"),
+                (int) metadata.get("qwen3.feed_forward_length"),
+                (int) metadata.get("qwen3.block_count"),
+                (int) metadata.get("qwen3.attention.head_count"),
+
+                metadata.containsKey("qwen3.attention.head_count_kv")
+                        ? (int) metadata.get("qwen3.attention.head_count_kv")
+                        : (int) metadata.get("qwen3.attention.head_count"),
+                (int) metadata.get("qwen3.attention.key_length"),
+                (int) metadata.get("qwen3.attention.value_length"),
+
+                vocabSize,
+                modelContextLength,
+                finalContextLength,
                 false,
-                0,
-                0,
-                0,
-                0
+                (float) metadata.get("qwen3.attention.layer_norm_rms_epsilon"),
+                (float) metadata.get("qwen3.rope.freq_base")
         );
+    }
 
-        GGMLTensorEntry tokenEmbeddings = tensorEntries.get("token_embd.weight");
-        GGMLTensorEntry outputWeight = tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
-
-        if (useTornadovm) {
-            if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
-                System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + ")");
-            }
-            if (outputWeight.ggmlType() == GGMLType.Q8_0) {
-                return createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
-            } else {
-                return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
-            }
-        } else {
-            return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
-        }
+    @Override
+    protected Pair<float[], float[]> precomputeRopeFrequencies(Qwen3Configuration config) {
+        return RoPE.precomputeFreqsCis(config.contextLengthModel(), config.numberOfHeadsKey(), config.ropeTheta(), false, 0, 0, 0, 0);
+    }
+
+    @Override
+    protected Qwen3 createModel(Qwen3Configuration config, Tokenizer tokenizer, Weights weights) {
+        Map<String, Object> metadata = gguf.getMetadata();
+        boolean isDeepSeekR1DistillQwen = "DeepSeek-R1-Distill-Qwen".equals(metadata.get("general.basename"));
+        // Qwen2.5-coder uses <|endoftext|> as stop-token.
+        ChatTokens chatTokens = isDeepSeekR1DistillQwen
+                ? new ChatTokens("<|begin▁of▁sentence|>", "", "", "<|end▁of▁sentence|>", "")
+                : new ChatTokens("<|im_start|>", "<|im_end|>", "", "<|end_of_text|>", "<|endoftext|>");
+        return new Qwen3(config, tokenizer, weights, ChatFormat.create(tokenizer, chatTokens));
     }
-    // @formatter:on
 
-    // @formatter:off
     @Override
-    public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+    protected Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
             GGMLTensorEntry outputWeight) {
+        float[] ropeFreqsReal = ropeFreqs.first();
+        float[] ropeFreqsImag = ropeFreqs.second();
+        return new Qwen3StandardWeights(
+                loadQuantized(tokenEmbeddings),
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), // wq
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), // wk
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), // wv
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // attnKNorm
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // attnQNorm
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), // rms_ffn_weight
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
+                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
+                loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight
+                new ArrayFloatTensor(ropeFreqsReal),
+                new ArrayFloatTensor(ropeFreqsImag),
+                tensorEntries.containsKey("output.weight")
+                        ? ModelLoader.loadQuantized(tensorEntries.get("output.weight"))
+                        : loadQuantized(tokenEmbeddings), // weights are shared
+                null
+        );
+    }
+
+    @Override
+    protected Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
+        if (TornadoVMMasterPlan.ENABLE_TORNADOVM_INIT_TIME) {
+            System.out.println("Loading model weights in TornadoVM format (loading " + outputWeight.ggmlType() + " -> " + GGMLType.F16 + ")");
+        }
+
+        GGMLType ggmlType = outputWeight.ggmlType();
+        return switch (ggmlType) {
+            case F16 -> createTornadoVMWeightsF16(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+            case Q8_0 -> createTornadoVMWeightsQ8_0(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
+            default -> throw new UnsupportedOperationException("Type: " + ggmlType + " currently not supported for TornadoVM weights.");
+        };
+    }
+
+    private Weights createTornadoVMWeightsF16(Map<String, GGMLTensorEntry> tensorEntries, Qwen3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
         return new Qwen3TornadoWeights(
                 loadTensorAsFloatArray(tokenEmbeddings),
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
@@ -142,8 +159,8 @@ public Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries,
         );
     }
 
-    public Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries, Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
-            GGMLTensorEntry outputWeight) {
+    private Q8_0Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries, Qwen3Configuration config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
+            GGMLTensorEntry outputWeight) {
         return new Qwen3Q8_0TornadoWeights(
                 loadTensorAsFloatArray(tokenEmbeddings),
                 loadArrayAsFloatArrayFromBuffer(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")),
@@ -164,40 +181,32 @@ public Weights createTornadoVMWeightsQ8_0(Map<String, GGMLTensorEntry> tensorEntries,
                 outputWeight.ggmlType()
         );
     }
-    // @formatter:on
 
-    // @formatter:off
-    @Override
-    public Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries,
-            Configuration config,
-            Pair<float[], float[]> ropeFreqs,
-            GGMLTensorEntry tokenEmbeddings,
-            GGMLTensorEntry outputWeight) {
-        float[] ropeFreqsReal = ropeFreqs.first();
-        float[] ropeFreqsImag = ropeFreqs.second();
-        return new Qwen3StandardWeights(
-                loadQuantized(tokenEmbeddings),
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_norm.weight")), // rms_att_weight
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q.weight")), // wq
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k.weight")), // wk
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_v.weight")), // wv
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_output.weight")), // wo
-
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_k_norm.weight")), // attnKNorm
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".attn_q_norm.weight")), // attnQNorm
-
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_norm.weight")), //rms_ffn_weight
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_gate.weight")), // w1
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_down.weight")), // w2
-                loadArrayOfQuantized(config.numberOfLayers(), i -> tensorEntries.get("blk." + i + ".ffn_up.weight")), // w3
-                loadQuantized(tensorEntries.get("output_norm.weight")), // rms_final_weight
-                new ArrayFloatTensor(ropeFreqsReal),
-                new ArrayFloatTensor(ropeFreqsImag),
-                tensorEntries.containsKey("output.weight")
-                        ? ModelLoader.loadQuantized(tensorEntries.get("output.weight"))
-                        : loadQuantized(tokenEmbeddings), // weights are shared
-                null
-        );
+    // Helper methods
+    private FloatTensor[] loadLayerWeights(Map<String, GGMLTensorEntry> tensorEntries, Qwen3Configuration config, String layerName, String suffix) {
+        FloatTensor[] weights = new FloatTensor[config.numberOfLayers()];
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+            weights[i] = ModelLoader.loadQuantized(tensorEntries.get(key));
+        }
+        return weights;
+    }
+
+    private FloatArray[] loadLayerWeightsAsFloatArraysFromBuffer(Map<String, GGMLTensorEntry> tensorEntries, Qwen3Configuration config, String layerName, String suffix) {
+        FloatArray[] weights = new FloatArray[config.numberOfLayers()];
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+            weights[i] = ModelLoader.floatBufferToFloatArray(tensorEntries.get(key));
+        }
+        return weights;
+    }
+
+    private HalfFloatArray[] loadLayerWeightsAsHalfFloatArrays(Map<String, GGMLTensorEntry> tensorEntries, Qwen3Configuration config, String layerName, String suffix) {
+        HalfFloatArray[] weights = new HalfFloatArray[config.numberOfLayers()];
+        for (int i = 0; i < config.numberOfLayers(); i++) {
+            String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
+            weights[i] = ModelLoader.loadTensorAsHalfFloatArray(tensorEntries.get(key));
+        }
+        return weights;
     }
-    // @formatter:on
 }
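Qwen3ModelLoader closes with the same three private helpers that Phi3ModelLoader and Qwen2ModelLoader gained, identical except for the Configuration subtype, and none of them appear to be referenced yet by the weight-factory methods shown in this patch. If they are kept, a single generic version in a shared location would avoid the triplication. The following is a hypothetical consolidation, not part of the patch:

```java
import java.util.Map;
import java.util.function.Function;
import java.util.function.IntFunction;

final class LayerWeightLoading {
    // Loads one entry per transformer block, keyed "blk.<i>.<layerName>.<suffix>".
    static <T> T[] loadLayerArray(Map<String, GGMLTensorEntry> tensorEntries, int numberOfLayers,
            IntFunction<T[]> arrayFactory, Function<GGMLTensorEntry, T> loadOne,
            String layerName, String suffix) {
        T[] weights = arrayFactory.apply(numberOfLayers);
        for (int i = 0; i < numberOfLayers; i++) {
            String key = String.format("blk.%d.%s.%s", i, layerName, suffix);
            weights[i] = loadOne.apply(tensorEntries.get(key));
        }
        return weights;
    }
}
```

A call site would then read, for example, `loadLayerArray(tensorEntries, config.numberOfLayers(), FloatTensor[]::new, ModelLoader::loadQuantized, "attn_norm", "weight")`, with the element type and per-tensor loader varying per weight family.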
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java
index 6cfdb821..6c5b0238 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlanner.java
@@ -2,7 +2,7 @@
 
 import org.beehive.gpullama3.auxiliary.Tuple2;
 import org.beehive.gpullama3.inference.state.Phi3State;
-import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.Phi3TornadoWeights;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.model.phi3.Phi3Configuration;
 import uk.ac.manchester.tornado.api.GridScheduler;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlannerQ8_0.java b/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlannerQ8_0.java
index dbdd204a..0197a655 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlannerQ8_0.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/Phi3TornadoVMLayerPlannerQ8_0.java
@@ -2,7 +2,7 @@
 
 import org.beehive.gpullama3.auxiliary.Tuple2;
 import org.beehive.gpullama3.inference.state.Phi3State;
-import org.beehive.gpullama3.inference.weights.tornado.Phi3TornadoWeightsQ8_0;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Phi3TornadoWeightsQ8_0;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.model.phi3.Phi3Configuration;
 import uk.ac.manchester.tornado.api.GridScheduler;
diff --git a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2Q8_0TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2Q8_0TornadoVMLayerPlanner.java
similarity index 99%
rename from src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2Q8_0TornadoVMLayerPlanner.java
rename to src/main/java/org/beehive/gpullama3/tornadovm/Qwen2Q8_0TornadoVMLayerPlanner.java
index 4884e4af..1d109a04 100644
--- a/src/main/java/org/beehive/gpullama3/inference/weights/tornado/Qwen2Q8_0TornadoVMLayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2Q8_0TornadoVMLayerPlanner.java
@@ -1,7 +1,8 @@
-package org.beehive.gpullama3.inference.weights.tornado;
+package org.beehive.gpullama3.tornadovm;
 
 import org.beehive.gpullama3.auxiliary.Tuple2;
 import org.beehive.gpullama3.inference.state.Qwen2State;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Qwen2TornadoWeightsQ8_0;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
 import org.beehive.gpullama3.tornadovm.Qwen2Kernels;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java
index 1f9d547b..e3155afa 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen2TornadoVMLayerPlanner.java
@@ -2,7 +2,7 @@
 
 import org.beehive.gpullama3.auxiliary.Tuple2;
 import org.beehive.gpullama3.inference.state.Qwen2State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen2TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.Qwen2TornadoWeights;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.model.qwen2.Qwen2Configuration;
 import uk.ac.manchester.tornado.api.GridScheduler;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3Q8_0TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3Q8_0TornadoVMLayerPlanner.java
index fd294965..4942ce2f 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3Q8_0TornadoVMLayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3Q8_0TornadoVMLayerPlanner.java
@@ -2,7 +2,7 @@
 
 import org.beehive.gpullama3.auxiliary.Tuple2;
 import org.beehive.gpullama3.inference.state.Qwen3State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen3Q8_0TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Qwen3Q8_0TornadoWeights;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
 import uk.ac.manchester.tornado.api.GridScheduler;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java
index 57d08a90..e04e8eef 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/Qwen3TornadoVMLayerPlanner.java
@@ -2,7 +2,7 @@
 
 import org.beehive.gpullama3.auxiliary.Tuple2;
 import org.beehive.gpullama3.inference.state.Qwen3State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen3TornadoWeights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.Qwen3TornadoWeights;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.model.qwen3.Qwen3Configuration;
 import uk.ac.manchester.tornado.api.GridScheduler;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java
index 4849b847..02ccf272 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMLayerPlanner.java
@@ -1,7 +1,7 @@
 package org.beehive.gpullama3.tornadovm;
 
 import org.beehive.gpullama3.auxiliary.Tuple2;
-import org.beehive.gpullama3.inference.weights.tornado.FP16Weights;
+import org.beehive.gpullama3.inference.weights.tornado.fp16.FP16Weights;
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.inference.state.State;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
index 1e420b1a..8cae8eac 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMMasterPlan.java
@@ -6,7 +6,6 @@
 import org.beehive.gpullama3.inference.state.Qwen2State;
 import org.beehive.gpullama3.inference.state.Qwen3State;
 import org.beehive.gpullama3.inference.state.State;
-import org.beehive.gpullama3.inference.weights.tornado.Qwen2Q8_0TornadoVMLayerPlanner;
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.Model;
 import org.beehive.gpullama3.model.ModelType;
diff --git a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMQ8_0LayerPlanner.java b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMQ8_0LayerPlanner.java
index 347f3267..1173a694 100644
--- a/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMQ8_0LayerPlanner.java
+++ b/src/main/java/org/beehive/gpullama3/tornadovm/TornadoVMQ8_0LayerPlanner.java
@@ -2,7 +2,7 @@
 
 import org.beehive.gpullama3.auxiliary.Tuple2;
 import org.beehive.gpullama3.inference.state.State;
-import org.beehive.gpullama3.inference.weights.tornado.Q8_0Weights;
+import org.beehive.gpullama3.inference.weights.tornado.q8_0.Q8_0Weights;
 import org.beehive.gpullama3.model.Configuration;
 import org.beehive.gpullama3.model.Model;
 import uk.ac.manchester.tornado.api.GridScheduler;