Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src/main/java/org/beehive/gpullama3/LlamaApp.java
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
package org.beehive.gpullama3;

import org.beehive.gpullama3.aot.AOT;
import org.beehive.gpullama3.auxiliary.LastRunMetrics;
import org.beehive.gpullama3.inference.sampler.Sampler;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.model.loader.ModelLoader;

import java.io.IOException;

Expand Down
85 changes: 0 additions & 85 deletions src/main/java/org/beehive/gpullama3/aot/AOT.java

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.beehive.gpullama3.inference.weights.tornado;
package org.beehive.gpullama3.inference.weights.tornado.fp16;

import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.HalfFloatArray;

Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.beehive.gpullama3.inference.weights.tornado;
package org.beehive.gpullama3.inference.weights.tornado.fp16;

import org.beehive.gpullama3.core.model.GGMLType;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.beehive.gpullama3.inference.weights.tornado;
package org.beehive.gpullama3.inference.weights.tornado.fp16;

import org.beehive.gpullama3.core.model.GGMLType;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.beehive.gpullama3.inference.weights.tornado;
package org.beehive.gpullama3.inference.weights.tornado.fp16;

import org.beehive.gpullama3.core.model.GGMLType;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.beehive.gpullama3.inference.weights.tornado;
package org.beehive.gpullama3.inference.weights.tornado.fp16;

import org.beehive.gpullama3.core.model.GGMLType;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.beehive.gpullama3.inference.weights.tornado;
package org.beehive.gpullama3.inference.weights.tornado.q8_0;

import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
package org.beehive.gpullama3.inference.weights.tornado;
package org.beehive.gpullama3.inference.weights.tornado.q8_0;

import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
import org.beehive.gpullama3.inference.weights.tornado.TornadoWeights;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;

public class Q8_0Weights implements TornadoWeights {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.beehive.gpullama3.inference.weights.tornado;
package org.beehive.gpullama3.inference.weights.tornado.q8_0;

import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package org.beehive.gpullama3.inference.weights.tornado;
package org.beehive.gpullama3.inference.weights.tornado.q8_0;

import org.beehive.gpullama3.core.model.GGMLType;
import org.beehive.gpullama3.core.model.tensor.Q8_0QuantizedTensor;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
package org.beehive.gpullama3.model.loader;

import org.beehive.gpullama3.core.model.GGUF;
import org.beehive.gpullama3.core.model.tensor.GGMLTensorEntry;
import org.beehive.gpullama3.core.types.Pair;
import org.beehive.gpullama3.inference.weights.Weights;
import org.beehive.gpullama3.model.Configuration;
import org.beehive.gpullama3.model.Model;
import org.beehive.gpullama3.tokenizer.impl.Tokenizer;
import org.beehive.gpullama3.tokenizer.vocabulary.Vocabulary;

import java.io.IOException;
import java.nio.channels.FileChannel;
import java.util.Map;

/**
* Abstract base class for model loaders using Template Method pattern. Provides common loading flow with extension points for model-specific logic.
*
* @param <M>
* The specific Model type to load
* @param <C>
* The specific Configuration type for the model
*/
public abstract class AbstractModelLoader<M extends Model, C extends Configuration> {

protected final FileChannel fileChannel;
protected final GGUF gguf;
protected final int contextLength;
protected final boolean loadWeights;
protected final boolean useTornadovm;

protected Vocabulary vocabulary;

protected AbstractModelLoader(FileChannel fileChannel, GGUF gguf, int contextLength, boolean loadWeights, boolean useTornadovm) {
this.fileChannel = fileChannel;
this.gguf = gguf;
this.contextLength = contextLength;
this.loadWeights = loadWeights;
this.useTornadovm = useTornadovm;
}

/**
* Template method that defines the model loading workflow. Subclasses should not override this method.
*
* @return The loaded model instance
*/
public final M loadModel() {
try {
Map<String, Object> metadata = gguf.getMetadata();

// Step 1: Load vocabulary
this.vocabulary = loadVocabulary(metadata);

// Step 2: Create tokenizer
Tokenizer tokenizer = createTokenizer(metadata, vocabulary);

// Step 3: Create configuration
C config = createConfiguration(metadata);

// Step 4: Load weights (if requested)
Weights weights = null;
if (loadWeights) {
Map<String, GGMLTensorEntry> tensorEntries = GGUF.loadTensors(fileChannel, gguf.getTensorDataOffset(), gguf.getTensorInfos());
weights = loadWeights(tensorEntries, config);
}

// Step 5: Create and return model instance
return createModel(config, tokenizer, weights);

} catch (IOException e) {
throw new ModelLoadException("Failed to load model", e);
}
}

/**
* Load the vocabulary from GGUF metadata. Model-specific implementations should override this method.
*
* @param metadata
* The GGUF metadata map
* @return The loaded Vocabulary
*/
protected abstract Vocabulary loadVocabulary(Map<String, Object> metadata);

/**
* Create a tokenizer instance for this model.
*
* @param metadata
* The GGUF metadata map
* @param vocabulary
* The loaded vocabulary
* @return The tokenizer instance
*/
protected abstract Tokenizer createTokenizer(Map<String, Object> metadata, Vocabulary vocabulary);

/**
* Create a configuration instance from GGUF metadata.
*
* @param metadata
* The GGUF metadata map
* @return The configuration instance
*/
protected abstract C createConfiguration(Map<String, Object> metadata);

/**
* Load model weights from tensor entries. Default implementation handles common weight loading logic.
*
* @param tensorEntries
* Map of tensor names to tensor entries
* @param config
* The model configuration
* @return The loaded weights
*/
public Weights loadWeights(Map<String, GGMLTensorEntry> tensorEntries, C config) {
// Precompute RoPE frequencies
Pair<float[], float[]> ropeFreqs = precomputeRopeFrequencies(config);

// Get token embeddings and output weights
GGMLTensorEntry tokenEmbeddings = getTokenEmbeddings(tensorEntries);
GGMLTensorEntry outputWeight = getOutputWeight(tensorEntries, tokenEmbeddings);

// Delegate to specific implementation
if (useTornadovm) {
return createTornadoVMWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
} else {
return createStandardWeights(tensorEntries, config, ropeFreqs, tokenEmbeddings, outputWeight);
}
}

/**
* Create the final model instance.
*
* @param config
* The model configuration
* @param tokenizer
* The tokenizer
* @param weights
* The loaded weights
* @return The model instance
*/
protected abstract M createModel(C config, Tokenizer tokenizer, Weights weights);

/**
* Precompute RoPE frequencies for this model. Default implementation can be overridden for custom RoPE configurations.
*/
protected abstract Pair<float[], float[]> precomputeRopeFrequencies(C config);

/**
* Get token embeddings tensor entry. Default implementation can be overridden for different tensor naming.
*/
protected GGMLTensorEntry getTokenEmbeddings(Map<String, GGMLTensorEntry> tensorEntries) {
return tensorEntries.get("token_embd.weight");
}

/**
* Get output weight tensor entry. Default implementation falls back to token embeddings if output.weight not found.
*/
protected GGMLTensorEntry getOutputWeight(Map<String, GGMLTensorEntry> tensorEntries, GGMLTensorEntry tokenEmbeddings) {
return tensorEntries.getOrDefault("output.weight", tokenEmbeddings);
}

/**
* Create standard (CPU) weights.
*/
protected abstract Weights createStandardWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
GGMLTensorEntry outputWeight);

/**
* Create TornadoVM (GPU) weights.
*/
protected abstract Weights createTornadoVMWeights(Map<String, GGMLTensorEntry> tensorEntries, C config, Pair<float[], float[]> ropeFreqs, GGMLTensorEntry tokenEmbeddings,
GGMLTensorEntry outputWeight);
}
Loading