Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improvements to EF type mapping #50

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 0 additions & 19 deletions src/Pgvector.EntityFrameworkCore/HalfvecTypeMapping.cs

This file was deleted.

11 changes: 0 additions & 11 deletions src/Pgvector.EntityFrameworkCore/HalfvecTypeMappingSourcePlugin.cs

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,10 @@
<None Include="..\..\LICENSE.txt" Pack="true" PackagePath="\" />
<None Include="..\..\README.md" Pack="true" PackagePath="\" />
<None Include="..\..\icon.png" Pack="true" PackagePath="\" />
<None Include="build\**\*">
<Pack>True</Pack>
<PackagePath>build</PackagePath>
</None>
<ProjectReference Include="..\Pgvector\Pgvector.csproj" />
<PackageReference Include="Npgsql.EntityFrameworkCore.PostgreSQL" Version="8.0.0" />
</ItemGroup>
Expand Down
19 changes: 0 additions & 19 deletions src/Pgvector.EntityFrameworkCore/SparsevecTypeMapping.cs

This file was deleted.

This file was deleted.

26 changes: 26 additions & 0 deletions src/Pgvector.EntityFrameworkCore/VectorCodeGeneratorPlugin.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
using System;
using System.Reflection;
using Microsoft.EntityFrameworkCore;
using Microsoft.EntityFrameworkCore.Design;
using Microsoft.EntityFrameworkCore.Scaffolding;
using Npgsql.EntityFrameworkCore.PostgreSQL.Infrastructure;

namespace Pgvector.EntityFrameworkCore;

public class VectorCodeGeneratorPlugin : ProviderCodeGeneratorPlugin
{
private static readonly MethodInfo _useVectorMethodInfo
= typeof(VectorDbContextOptionsBuilderExtensions).GetMethod(
nameof(VectorDbContextOptionsBuilderExtensions.UseVector),
[typeof(NpgsqlDbContextOptionsBuilder)])!;

/// <summary>
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
/// any release. You should only use it directly in your code with extreme caution and knowing that
/// doing so can result in application failures when updating to a new Entity Framework Core release.
/// </summary>
public override MethodCallCodeFragment GenerateProviderOptions()
=> new(_useVectorMethodInfo);

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@ public void ApplyServices(IServiceCollection services)
.TryAdd<IMethodCallTranslatorPlugin, VectorDbFunctionsTranslatorPlugin>();

services.AddSingleton<IRelationalTypeMappingSourcePlugin, VectorTypeMappingSourcePlugin>();
services.AddSingleton<IRelationalTypeMappingSourcePlugin, HalfvecTypeMappingSourcePlugin>();
services.AddSingleton<IRelationalTypeMappingSourcePlugin, SparsevecTypeMappingSourcePlugin>();
}

public void Validate(IDbContextOptions options) { }
Expand Down
14 changes: 14 additions & 0 deletions src/Pgvector.EntityFrameworkCore/VectorDesignTimeServices.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using Microsoft.EntityFrameworkCore.Design;
using Microsoft.EntityFrameworkCore.Scaffolding;
using Microsoft.EntityFrameworkCore.Storage;
using Microsoft.Extensions.DependencyInjection;

namespace Pgvector.EntityFrameworkCore;

public class VectorDesignTimeServices : IDesignTimeServices
{
public virtual void ConfigureDesignTimeServices(IServiceCollection serviceCollection)
=> serviceCollection
.AddSingleton<IRelationalTypeMappingSourcePlugin, VectorTypeMappingSourcePlugin>()
.AddSingleton<IProviderCodeGeneratorPlugin, VectorCodeGeneratorPlugin>();
}
15 changes: 11 additions & 4 deletions src/Pgvector.EntityFrameworkCore/VectorTypeMapping.cs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,18 @@ namespace Pgvector.EntityFrameworkCore;

public class VectorTypeMapping : RelationalTypeMapping
{
public static VectorTypeMapping Default { get; } = new();
public static VectorTypeMapping Default { get; } = new("vector", typeof(Vector));

public VectorTypeMapping() : base("vector", typeof(Vector)) { }

public VectorTypeMapping(string storeType) : base(storeType, typeof(Vector)) { }
public VectorTypeMapping(string storeType, Type clrType, int? size = null)
: this(
new RelationalTypeMappingParameters(
new CoreTypeMappingParameters(clrType),
storeType,
StoreTypePostfix.Size,
size: size,
fixedLength: size is not null))
{
}

protected VectorTypeMapping(RelationalTypeMappingParameters parameters) : base(parameters) { }

Expand Down
29 changes: 26 additions & 3 deletions src/Pgvector.EntityFrameworkCore/VectorTypeMappingSourcePlugin.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,30 @@ namespace Pgvector.EntityFrameworkCore;
public class VectorTypeMappingSourcePlugin : IRelationalTypeMappingSourcePlugin
{
public RelationalTypeMapping? FindMapping(in RelationalTypeMappingInfo mappingInfo)
=> mappingInfo.ClrType == typeof(Vector)
? new VectorTypeMapping(mappingInfo.StoreTypeName ?? "vector")
: null;
{
if (mappingInfo.StoreTypeName is not null)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new logic for this method seems a bit complex. Thoughts on checking the ClrType first, then falling back to StoreTypeName?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The issue is that users may provide both a store type and a CLR type - this can be done by annotating a Vector property in the model with a store type (e.g. vector(3)). In that case we want to go into the top piece of logic, looking at the user-provided string (which may include the dimension), and not just look at the CLR type.

So there's basically three possibilities here:

  • Only a store type is provided (ClrType is null) - this is scaffolding. We go into the code above, simple.
  • Only a CLR type is provided (StoreType is null) - this is the regular case of a Vector property on the user's type, which isn't otherwise configured for the column type. We go into the code below, also simple.
  • Both StoreType and ClrType are provided. We go into the code above, do the same as scaffolding, except that we also validate that the user's CLR type is compatible with their requested store type (so they don't get it wrong and e.g. put [Column(TypeName = "sparsevec")] on a Vector property by accident.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the third case, I think it's better to use the CLR type (at least for now) for backwards compatibility.

Also, the error message when there is a mismatch is a bit confusing.

[Column("half_embedding", TypeName = "vector(3)")]
public HalfVector? HalfEmbedding { get; set; }

outputs

The property 'Item.HalfEmbedding' could not be mapped because it is of type 'HalfVector', which is not a supported primitive type or a valid entity type.

(however, HalfVector is a valid type)

Copy link
Contributor Author

@roji roji Mar 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In the third case, I think it's better to use the CLR type (at least for now) for backwards compatibility.

Are you saying that if a user specifies an incompatible combination, e.g. a Vector (or even string) CLR type with halfvec as the store type, we should just ignore and return an incorrect type mapping? That is bound to fail later anyway (when you try to actually read/write the thing), only with a more cryptic message - so I don't think there should be a backwards compatibility concern... Or am I missing something?

FWIW this behavior is what's generally done in EF provider and other extensions.

Also, the error message when there is a mismatch is a bit confusing

Yeah, that's something that should be improved on the EF side - can you open an issue on the EF repo?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. The tests currently pass even with a type mismatch - diff and CI logs (typecasting must be happening somewhere).
  2. Will do.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests currently pass even with a type mismatch - diff and CI logs (typecasting must be happening somewhere).

That's odd - is there any sort of implicit casting between the vector types? Even if so, try using string as the CLR type or similar.

In any case, I'd consider this a bug - the user is specifying an invalid combination of store and CLR types. But if you really prefer, I can remove the compatibility check for the CLR type, and make the store type be the only thing that matters once it's specified. Let me know.

{
VectorTypeMapping? mapping = (mappingInfo.StoreTypeNameBase ?? mappingInfo.StoreTypeName) switch
{
"vector" => new(mappingInfo.StoreTypeName, typeof(Vector), mappingInfo.Size),
"halfvec" => new(mappingInfo.StoreTypeName, typeof(HalfVector), mappingInfo.Size),
"sparsevec" => new(mappingInfo.StoreTypeName, typeof(SparseVector), mappingInfo.Size),
_ => null,
};

// If the caller hasn't specified a CLR type (this is scaffolding), or if the user has specified
// the one matching the store type, return the mapping.
return mappingInfo.ClrType is null || mappingInfo.ClrType == mapping?.ClrType
? mapping : null;
}

// No store type specified, look up by the CLR type only
return mappingInfo.ClrType switch
{
var t when t == typeof(Vector) => new VectorTypeMapping("vector", typeof(Vector), mappingInfo.Size),
var t when t == typeof(HalfVector) => new VectorTypeMapping("halfvec", typeof(HalfVector), mappingInfo.Size),
var t when t == typeof(SparseVector) => new VectorTypeMapping("sparsevec", typeof(SparseVector), mappingInfo.Size),
_ => null,
};
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
<Project>
<PropertyGroup>
<MSBuildAllProjects>$(MSBuildAllProjects);$(MSBuildThisFileFullPath)</MSBuildAllProjects>
<EFCoreNpgsqlPgvectorFile>$(IntermediateOutputPath)EFCoreNpgsqlPgvector$(DefaultLanguageSourceExtension)</EFCoreNpgsqlPgvectorFile>
</PropertyGroup>
<Choose>
<When Condition="'$(Language)' == 'F#'">
<Choose>
<When Condition="'$(OutputType)' == 'Exe' OR '$(OutputType)' == 'WinExe'">
<PropertyGroup>
<CodeFragmentItemGroup>CompileBefore</CodeFragmentItemGroup>
</PropertyGroup>
</When>
<Otherwise>
<PropertyGroup>
<CodeFragmentItemGroup>CompileAfter</CodeFragmentItemGroup>
</PropertyGroup>
</Otherwise>
</Choose>
</When>
<Otherwise>
<PropertyGroup>
<CodeFragmentItemGroup>Compile</CodeFragmentItemGroup>
</PropertyGroup>
</Otherwise>
</Choose>
<Target Name="AddEFCoreNpgsqlPgvector"
BeforeTargets="CoreCompile"
DependsOnTargets="PrepareForBuild"
Condition="'$(DesignTimeBuild)' != 'True'"
Inputs="$(MSBuildAllProjects)"
Outputs="$(EFCoreNpgsqlPgvectorFile)">
<ItemGroup>
<EFCoreNpgsqlPgvectorServices Include="Microsoft.EntityFrameworkCore.Design.DesignTimeServicesReferenceAttribute">
<_Parameter1>Pgvector.EntityFrameworkCore.VectorDesignTimeServices, Pgvector.EntityFrameworkCore</_Parameter1>
<_Parameter2>Npgsql.EntityFrameworkCore.PostgreSQL</_Parameter2>
</EFCoreNpgsqlPgvectorServices>
</ItemGroup>
<WriteCodeFragment AssemblyAttributes="@(EFCoreNpgsqlPgvectorServices)"
Language="$(Language)"
OutputFile="$(EFCoreNpgsqlPgvectorFile)">
<Output TaskParameter="OutputFile" ItemName="$(CodeFragmentItemGroup)" />
<Output TaskParameter="OutputFile" ItemName="FileWrites" />
</WriteCodeFragment>
</Target>
</Project>
74 changes: 55 additions & 19 deletions tests/Pgvector.CSharp.Tests/EntityFrameworkCoreTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -66,79 +66,115 @@ public async Task Main()

var embedding = new Vector(new float[] { 1, 1, 1 });
var items = await ctx.Items.FromSql($"SELECT * FROM efcore_items ORDER BY embedding <-> {embedding} LIMIT 5").ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray());
Assert.Equal(new Half[] { (Half)1, (Half)1, (Half)1 }, items[0].HalfEmbedding!.ToArray());
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
Assert.Equal([1, 1, 1], items[0].Embedding!.ToArray());
Assert.Equal([(Half)1, (Half)1, (Half)1], items[0].HalfEmbedding!.ToArray());
Assert.Equal(new BitArray(new bool[] { false, false, false }), items[0].BinaryEmbedding!);
Assert.Equal(new float[] { 1, 1, 1 }, items[0].SparseEmbedding!.ToArray());
Assert.Equal([1, 1, 1], items[0].SparseEmbedding!.ToArray());

// vector distance functions

items = await ctx.Items.OrderBy(x => x.Embedding!.L2Distance(embedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
Assert.Equal(new float[] { 1, 1, 1 }, items[0].Embedding!.ToArray());
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());
Assert.Equal([1, 1, 1], items[0].Embedding!.ToArray());

items = await ctx.Items.OrderBy(x => x.Embedding!.MaxInnerProduct(embedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());

items = await ctx.Items.OrderBy(x => x.Embedding!.CosineDistance(embedding)).Take(5).ToListAsync();
Assert.Equal(3, items[2].Id);

items = await ctx.Items.OrderBy(x => x.Embedding!.L1Distance(embedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());

// halfvec distance functions

var halfEmbedding = new HalfVector(new Half[] { (Half)1, (Half)1, (Half)1 });
items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L2Distance(halfEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());

items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.MaxInnerProduct(halfEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());

items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.CosineDistance(halfEmbedding)).Take(5).ToListAsync();
Assert.Equal(3, items[2].Id);

items = await ctx.Items.OrderBy(x => x.HalfEmbedding!.L1Distance(halfEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());

// sparsevec distance functions

var sparseEmbedding = new SparseVector(new float[] { 1, 1, 1 });
items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L2Distance(sparseEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());

items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.MaxInnerProduct(sparseEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());

items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.CosineDistance(sparseEmbedding)).Take(5).ToListAsync();
Assert.Equal(3, items[2].Id);

items = await ctx.Items.OrderBy(x => x.SparseEmbedding!.L1Distance(sparseEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, items.Select(v => v.Id).ToArray());
Assert.Equal([1, 3, 2], items.Select(v => v.Id).ToArray());

// bit distance functions

var binaryEmbedding = new BitArray(new bool[] { true, false, true });
items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.HammingDistance(binaryEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());

items = await ctx.Items.OrderBy(x => x.BinaryEmbedding!.JaccardDistance(binaryEmbedding)).Take(5).ToListAsync();
Assert.Equal(new int[] { 2, 3, 1 }, items.Select(v => v.Id).ToArray());
Assert.Equal([2, 3, 1], items.Select(v => v.Id).ToArray());

// additional

items = await ctx.Items
.OrderBy(x => x.Id)
.Where(x => x.Embedding!.L2Distance(embedding) < 1.5)
.ToListAsync();
Assert.Equal(new int[] { 1, 3 }, items.Select(v => v.Id).ToArray());
Assert.Equal([1, 3], items.Select(v => v.Id).ToArray());

var neighbors = await ctx.Items
.OrderBy(x => x.Embedding!.L2Distance(embedding))
.Select(x => new { Entity = x, Distance = x.Embedding!.L2Distance(embedding) })
.ToListAsync();
Assert.Equal(new int[] { 1, 3, 2 }, neighbors.Select(v => v.Entity.Id).ToArray());
Assert.Equal(new double[] { 0, 1, Math.Sqrt(3) }, neighbors.Select(v => v.Distance).ToArray());
Assert.Equal([1, 3, 2], neighbors.Select(v => v.Entity.Id).ToArray());
Assert.Equal([0, 1, Math.Sqrt(3)], neighbors.Select(v => v.Distance).ToArray());
}

[Theory]
[InlineData(typeof(Vector), null, "vector")]
[InlineData(typeof(Vector), 3, "vector(3)")]
[InlineData(typeof(HalfVector), null, "halfvec")]
[InlineData(typeof(HalfVector), 3, "halfvec(3)")]
[InlineData(typeof(SparseVector), null, "sparsevec")]
[InlineData(typeof(SparseVector), 3, "sparsevec(3)")]
public void By_StoreType(Type type, int? size, string expectedStoreType)
{
using var ctx = new ItemContext();
var typeMappingSource = ctx.GetService<IRelationalTypeMappingSource>();

var typeMapping = typeMappingSource.FindMapping(type, storeTypeName: null, size: size)!;
Assert.Equal(expectedStoreType, typeMapping.StoreType);
Assert.Same(type, typeMapping.ClrType);
Assert.Equal(size, typeMapping.Size);
}

[Theory]
[InlineData("vector", typeof(Vector), null)]
[InlineData("vector(3)", typeof(Vector), 3)]
[InlineData("halfvec", typeof(HalfVector), null)]
[InlineData("halfvec(3)", typeof(HalfVector), 3)]
[InlineData("sparsevec", typeof(SparseVector), null)]
[InlineData("sparsevec(3)", typeof(SparseVector), 3)]
public void By_ClrType(string storeType, Type expectedType, int? expectedSize)
{
using var ctx = new ItemContext();
var typeMappingSource = ctx.GetService<IRelationalTypeMappingSource>();

var typeMapping = typeMappingSource.FindMapping(storeType)!;
Assert.Equal(storeType, typeMapping.StoreType);
Assert.Same(expectedType, typeMapping.ClrType);
Assert.Equal(expectedSize, typeMapping.Size);
}
}