71 changes: 34 additions & 37 deletions LLama.Unittest/LLamaEmbedderTests.cs
@@ -41,43 +41,40 @@ private async Task CompareEmbeddings(string modelPath)

var spoon = (await embedder.GetEmbeddings("The spoon is not real")).Single().EuclideanNormalization();
Assert.DoesNotContain(float.NaN, spoon);

if (false)
{
//TODO: the below does not work with the new memory efficient context handling - we probably need to define Microsoft.Extensions.AI.IEmbeddingGenerator GetService interface that creates the context on the fly

var generator = (IEmbeddingGenerator<string, Embedding<float>>)embedder;
Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>());
Assert.Equal(nameof(LLamaEmbedder), generator.GetService<EmbeddingGeneratorMetadata>()?.ProviderName);
Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>()?.DefaultModelId);
Assert.NotEmpty(generator.GetService<EmbeddingGeneratorMetadata>()?.DefaultModelId!);
Assert.Same(embedder, generator.GetService<LLamaEmbedder>());
Assert.Same(generator, generator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
Assert.Null(generator.GetService<string>());

var embeddings = await generator.GenerateAsync(
[
"The cat is cute",
"The kitten is cute",
"The spoon is not real"
]);
Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));

_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");

var close = 1 - Dot(cat, kitten);
var far = 1 - Dot(cat, spoon);

_testOutputHelper.WriteLine("");
_testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}");
_testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}");

Assert.True(close < far);
}

using var context = new LLamaContext(weights, @params);
var managedEmbedder = new LLamaEmbedder(context);
IEmbeddingGenerator<string, Embedding<float>> generator = managedEmbedder;
Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>());
Assert.Equal(nameof(LLamaEmbedder), generator.GetService<EmbeddingGeneratorMetadata>()?.ProviderName);
Assert.NotNull(generator.GetService<EmbeddingGeneratorMetadata>()?.DefaultModelId);
Assert.NotEmpty(generator.GetService<EmbeddingGeneratorMetadata>()?.DefaultModelId!);
Assert.Same(managedEmbedder, generator.GetService<LLamaEmbedder>());
Assert.Same(generator, generator.GetService<IEmbeddingGenerator<string, Embedding<float>>>());
Assert.Null(generator.GetService<string>());

var embeddings = await generator.GenerateAsync(
[
"The cat is cute",
"The kitten is cute",
"The spoon is not real"
]);
Assert.All(cat.Zip(embeddings[0].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.All(kitten.Zip(embeddings[1].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));
Assert.All(spoon.Zip(embeddings[2].Vector.Span.EuclideanNormalization()), e => Assert.Equal(e.First, e.Second, 0.001));

_testOutputHelper.WriteLine($"Cat = [{string.Join(",", cat.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Kitten = [{string.Join(",", kitten.AsMemory().Slice(0, 7).ToArray())}...]");
_testOutputHelper.WriteLine($"Spoon = [{string.Join(",", spoon.AsMemory().Slice(0, 7).ToArray())}...]");

var close = 1 - Dot(cat, kitten);
var far = 1 - Dot(cat, spoon);

_testOutputHelper.WriteLine("");
_testOutputHelper.WriteLine($"Cat.Kitten (Close): {close:F4}");
_testOutputHelper.WriteLine($"Cat.Spoon (Far): {far:F4}");

Assert.True(close < far);
}

[Fact]
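The rewritten test exercises the new externally-owned-context pattern directly. As a rough sketch of the same pattern outside the test harness (the model path and parameter values below are placeholders, not values from this PR):

// Sketch only: assumes a local GGUF model and the usual LLamaSharp setup.
using LLama;
using LLama.Common;
using Microsoft.Extensions.AI;

var @params = new ModelParams("models/example.gguf") { Embeddings = true }; // placeholder path
using var weights = LLamaWeights.LoadFromFile(@params);

// The caller owns the context; the embedder wraps it instead of creating
// a context per call, which is what enables the IEmbeddingGenerator surface.
using var context = new LLamaContext(weights, @params);
var embedder = new LLamaEmbedder(context);

IEmbeddingGenerator<string, Embedding<float>> generator = embedder;
var embeddings = await generator.GenerateAsync(["The cat is cute"]);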
79 changes: 79 additions & 0 deletions LLama/LLamaContext.cs
@@ -109,6 +109,84 @@ public LLamaToken[] Tokenize(string text, bool addBos = true, bool special = false)
return NativeHandle.Tokenize(text, addBos, special, Encoding);
}

#region Sequence ID management
private LLamaSeqIdManager? _seqIdManager;

/// <summary>
/// Get the sequence ID manager for this context.
/// </summary>
public LLamaSeqIdManager SequenceManager
{
get
{
var manager = _seqIdManager;
if (manager != null) return manager;
var newManager = new LLamaSeqIdManager(Params.SeqMax);
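// Benign race: if another thread initialized the field first, CompareExchange
// returns that original instance and this thread's newManager is discarded.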
var original = Interlocked.CompareExchange(ref _seqIdManager, newManager, comparand: null);
manager = original ?? newManager;
return manager;
}
}

/// <summary>
/// Returns the next available sequence ID for use in model operations.
/// Callers will asynchronously wait if none are available.
/// On disposal, the sequence ID is returned to the owning <see cref="LLamaContext"/> for reuse.
/// </summary>
/// <remarks>
/// Failure to dispose the returned <see cref="ManagedLLamaSeqId"/> will likely result in undefined behavior.
/// </remarks>
/// <remarks>
/// The returned sequence represents an exclusive reservation on the sequence ID within the context.
/// For the duration of the <see cref="ManagedLLamaSeqId"/>, no other caller will receive the same sequence ID from this context.
/// </remarks>
/// <param name="removeMemoryOnRelease">flag indicating whether to remove memory associated with the sequence ID when it is released back to the manager.</param>
/// <param name="timeout">optional timeout for acquiring a sequence ID. If null, waits indefinitely.</param>
/// <param name="cancellationToken">cancellation token to cancel the wait operation.</param>
/// <returns>The next available sequence ID.</returns>
public async Task<ManagedLLamaSeqId> AcquireSequenceIdAsync(bool removeMemoryOnRelease = false, TimeSpan? timeout = null, CancellationToken cancellationToken = default)
{
var seqId = await SequenceManager.NextAsync(timeout, cancellationToken).ConfigureAwait(false);
return new ManagedLLamaSeqId(owner: this, seqId, removeMemoryOnRelease);
}

/// <summary>
/// Represents a managed <see cref="SeqId"/> that is returned to the owning <see cref="LLamaContext"/> when disposed.
/// </summary>
public readonly struct ManagedLLamaSeqId : IDisposable
{
private readonly LLamaContext? _owner;
private readonly bool _removeMemoryOnRelease;

/// <summary>
/// The sequence ID.
/// </summary>
public LLamaSeqId SeqId { get; }

/// <summary>
/// Implicit conversion to <see cref="LLamaSeqId"/>.
/// </summary>
/// <param name="managedSeqId">managed sequence ID.</param>
/// <returns>the underlying sequence ID.</returns>
public static implicit operator LLamaSeqId(ManagedLLamaSeqId managedSeqId) => managedSeqId.SeqId;

internal ManagedLLamaSeqId(LLamaContext owner, LLamaSeqId seqId, bool removeMemoryOnRelease)
{
_owner = owner;
SeqId = seqId;
_removeMemoryOnRelease = removeMemoryOnRelease;
}

/// <inheritdoc />
public void Dispose()
{
if (_owner == null || _owner.NativeHandle.IsClosed) return;
if (_removeMemoryOnRelease) _owner.NativeHandle.MemorySequenceRemove(SeqId, 0, -1);
_owner.SequenceManager.Return(SeqId);
}
}
#endregion

/// <summary>
/// Detokenize the tokens to text.
/// </summary>
@@ -441,6 +519,7 @@ public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationToken cancellationToken = default)
public void Dispose()
{
NativeHandle.Dispose();
_seqIdManager?.Dispose();
}

/// <summary>
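Because ManagedLLamaSeqId is disposable, the intended call pattern is a using scope, so the ID returns to the pool even when the guarded work throws. A minimal sketch against the API added above, assuming an existing LLamaContext named `context`:

// Sketch only: `context` is assumed to be a live LLamaContext.
using (var seq = await context.AcquireSequenceIdAsync(
           removeMemoryOnRelease: true,
           timeout: TimeSpan.FromSeconds(30)))
{
    // The implicit conversion lets the reservation be passed anywhere a raw
    // LLamaSeqId is expected.
    LLamaSeqId id = seq;
    // ... run decode/embedding work against this sequence ...
}
// Dispose removed the sequence's memory (removeMemoryOnRelease was set)
// and returned the ID to SequenceManager for reuse.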
42 changes: 19 additions & 23 deletions LLama/LLamaEmbedder.EmbeddingGenerator.cs
@@ -3,7 +3,6 @@
using System.Diagnostics;
using System.Threading;
using System.Threading.Tasks;
using LLama.Native;
using Microsoft.Extensions.AI;

namespace LLama;
@@ -16,25 +15,27 @@ public partial class LLamaEmbedder
/// <inheritdoc />
object? IEmbeddingGenerator.GetService(Type serviceType, object? serviceKey)
{
if (serviceKey is null)
if (serviceKey is not null)
{
if (serviceType == typeof(EmbeddingGeneratorMetadata))
{
return _metadata ??= new(
nameof(LLamaEmbedder),
defaultModelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null,
defaultModelDimensions: EmbeddingSize);
}
return null;
}

if (_hasExternalContext && serviceType == typeof(EmbeddingGeneratorMetadata))
{
return _metadata ??= new(
nameof(LLamaEmbedder),
defaultModelId: Context.NativeHandle.ModelHandle.ReadMetadata().TryGetValue("general.name", out var name) ? name : null,
defaultModelDimensions: EmbeddingSize);
}

if (serviceType?.IsInstanceOfType(Context) is true)
{
return Context;
}
if (_hasExternalContext && serviceType?.IsInstanceOfType(Context) is true)
{
return Context;
}

if (serviceType?.IsInstanceOfType(this) is true)
{
return this;
}
if (serviceType?.IsInstanceOfType(this) is true)
{
return this;
}

return null;
@@ -43,11 +44,6 @@ public partial class LLamaEmbedder
/// <inheritdoc />
async Task<GeneratedEmbeddings<Embedding<float>>> IEmbeddingGenerator<string, Embedding<float>>.GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)
{
if (Context.NativeHandle.PoolingType == LLamaPoolingType.None)
{
throw new NotSupportedException($"Embedding generation is not supported with {nameof(LLamaPoolingType)}.{nameof(LLamaPoolingType.None)}.");
}

GeneratedEmbeddings<Embedding<float>> results = new()
{
Usage = new() { InputTokenCount = 0 },
@@ -56,7 +52,7 @@ async Task<GeneratedEmbeddings<Embedding<float>>> IEmbeddingGenerator<string, Embedding<float>>.GenerateAsync(IEnumerable<string> values, EmbeddingGenerationOptions? options, CancellationToken cancellationToken)
foreach (var value in values)
{
var (embeddings, tokenCount) = await GetEmbeddingsWithTokenCount(value, cancellationToken).ConfigureAwait(false);
Debug.Assert(embeddings.Count == 1, "Should be one and only one embedding when pooling is enabled.");
Debug.Assert(embeddings.Count == 1, "Should be one and only one embedding returned from LLama for a single input string.");

results.Usage.InputTokenCount += tokenCount;
results.Add(new Embedding<float>(embeddings[0]) { CreatedAt = DateTime.UtcNow });
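For consumers, the net effect of the reworked GetService is that the metadata and context lookups now only succeed when the embedder wraps an externally supplied context (the _hasExternalContext checks above). A sketch of the consumer side, mirroring the assertions in the updated test:

// Sketch: service resolution through the Microsoft.Extensions.AI abstraction.
IEmbeddingGenerator<string, Embedding<float>> generator = embedder;

var metadata = generator.GetService<EmbeddingGeneratorMetadata>();
// metadata?.ProviderName is "LLamaEmbedder"; metadata?.DefaultModelId comes
// from the model's "general.name" GGUF metadata when present.

var sameEmbedder = generator.GetService<LLamaEmbedder>(); // the embedder itself
var context = generator.GetService<LLamaContext>();       // the wrapped context
var nothing = generator.GetService<string>();             // null for unknown services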